diff --git a/src/services/policy.ts b/src/services/policy.ts index 68e9611..5e06292 100644 --- a/src/services/policy.ts +++ b/src/services/policy.ts @@ -156,18 +156,31 @@ const extractEvents = ( }; /** - * recursively returns an array of contract called - * @param {*} trace - * @returns Arrray of string address + * Recursively returns an array of contract addresses. + * @param {FunctionInvocation} trace + * @param {number} currentDepth - The current depth in the tree. + * @param {number} maxDepth - The maximum depth to traverse in the tree. + * @returns {Array} An array of contract addresses. */ -const extractContractAddresses = (trace: FunctionInvocation): Array => { +const extractContractAddresses = ( + trace: FunctionInvocation, + currentDepth = 0, + maxDepth = Infinity +): Array => { + // Stop the recursion if the current depth exceeds the max depth + if (currentDepth > maxDepth) { + return []; + } + const contractAddresses: Array = trace.contract_address ? [trace.contract_address] : []; if (trace.internal_calls.length) { trace.internal_calls.forEach((internalCall: FunctionInvocation) => { - contractAddresses.push(...extractContractAddresses(internalCall)); + contractAddresses.push( + ...extractContractAddresses(internalCall, currentDepth + 1, maxDepth) + ); }); } @@ -328,8 +341,9 @@ const verifyPolicyWithTrace = ( ? extractEvents(trace.function_invocation) : []; const extractedAddresses = trace.function_invocation - ? extractContractAddresses(trace.function_invocation) + ? extractContractAddresses(trace.function_invocation, 0, 2) : []; + const userAddresses: string[] = extractAllowlistAddresses(policySanitized); // Filter extractedAddresses to only include addresses that are not in userAddressesSet const userAddressesSet = new Set(userAddresses).add(account); @@ -393,6 +407,7 @@ const findNFTIds = (policy: Policy, event: FunctionInvocation): boolean => { export default { verifyPolicy, verifyPolicyWithTrace, + extractContractAddresses, encodePolicy, getPolicies, }; diff --git a/test/policy.test.ts b/test/policy.test.ts index 30ef4fc..4151211 100644 --- a/test/policy.test.ts +++ b/test/policy.test.ts @@ -271,6 +271,20 @@ describe('policy detection tests', () => { "0x373c71f077b96cbe7a57225cd503d29cadb0056ed741a058094234d82de2f9", // Alpha Road: Pool Factory "0x61fdcf831f23d070b26a4fdc9d43c2fbba1928a529f51b5335cd7b738f97945" // Alpha Road: ETH/arfBTC LP */ + test('allowlist', async () => { + const trace = JSON.parse(readFileSync('test/txTrace1.json', 'utf8')); + const addresses = await policyService.extractContractAddresses( + trace.function_invocation, + 0, + 2 + ); + expect(addresses).toStrictEqual([ + '0x71dc40f7a57befa889f77d9c912523843a7fc978f4ee422f1b4573a80108b73', + '0x72df4dc5b6c4df72e4288857317caf2ce9da166ab8719ab8306516a2fddfff7', + '0x49d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7', + '0x4aec73f0611a9be0524e7ef21ab1679bdf9c97dc7d72614f15373d431226b6a', + ]); + }); test('policy pass', async () => { const trace = JSON.parse(readFileSync('test/txTrace1.json', 'utf8')); const policy = JSON.parse(readFileSync('test/policyERC20.json', 'utf8'));