Skip to content

Commit

Permalink
fix: nested lookahead (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Sep 1, 2023
1 parent e3f2cac commit 46f45f2
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 55 deletions.
63 changes: 37 additions & 26 deletions evm_trace/geth.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,37 +91,46 @@ def create_trace_frames(data: Iterator[Dict]) -> Iterator[TraceFrame]:
for frame in frames:
frame_obj = TraceFrame(**frame)

if CallType.CREATE.value in frame.get("op", ""):
if CallType.CREATE.value in frame_obj.op:
# Look ahead to find the address.

create_frames = [frame_obj]
start_depth = frame.get("depth", 0)
for next_frame in frames:
next_frame_obj = TraceFrame.parse_obj(next_frame)
depth = next_frame_obj.depth

if depth <= start_depth:
# Extract the address for the original CREATE using
# the first frame after the CREATE with an equal depth.
if len(next_frame_obj.stack) > 0:
raw_addr = HexBytes(next_frame_obj.stack[-1][-40:])
try:
frame_obj.contract_address = HexBytes(to_address(raw_addr))
except Exception:
# Potentially, a transaction was made with poor data.
frame_obj.contract_address = raw_addr

create_frames.append(next_frame_obj)
yield from create_frames
break

elif depth > start_depth:
create_frames.append(next_frame_obj)
create_frames = _get_create_frames(frame_obj, frames)
yield from create_frames

else:
yield TraceFrame(**frame)


def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceFrame]:
create_frames = [frame]
start_depth = frame.depth
for next_frame in frames:
next_frame_obj = TraceFrame.parse_obj(next_frame)
depth = next_frame_obj.depth

if CallType.CREATE.value in next_frame_obj.op:
# Handle CREATE within a CREATE.
create_frames.extend(_get_create_frames(next_frame_obj, frames))

elif depth <= start_depth:
# Extract the address for the original CREATE using
# the first frame after the CREATE with an equal depth.
if len(next_frame_obj.stack) > 0:
raw_addr = HexBytes(next_frame_obj.stack[-1][-40:])
try:
frame.contract_address = HexBytes(to_address(raw_addr))
except Exception:
# Potentially, a transaction was made with poor data.
frame.contract_address = raw_addr

create_frames.append(next_frame_obj)
break

elif depth > start_depth:
create_frames.append(next_frame_obj)

return create_frames


def get_calltree_from_geth_call_trace(data: Dict) -> CallTreeNode:
"""
Creates a CallTreeNode from a given transaction call trace.
Expand Down Expand Up @@ -271,7 +280,9 @@ def _create_node(
for subcall in node_kwargs.get("calls", [])[::-1]:
if subcall.call_type in (CallType.CREATE, CallType.CREATE2):
subcall.address = HexBytes(to_address(frame.stack[-1][-40:]))
subcall.calldata = frame.memory.get(frame.stack[-4], frame.stack[-5])
if len(frame.stack) >= 5:
subcall.calldata = frame.memory.get(frame.stack[-4], frame.stack[-5])

break

if frame.op in [x.value for x in CALL_OPCODES]:
Expand Down
3 changes: 1 addition & 2 deletions tests/data/geth/create2_structlogs.json

Large diffs are not rendered by default.

44 changes: 17 additions & 27 deletions tests/test_geth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
from ethpm_types import HexBytes
from pydantic import ValidationError
Expand Down Expand Up @@ -134,31 +136,13 @@ def test_get_call_tree_from_create2_struct_logs(geth_create2_trace_frames):
gas_limit=30000000,
calldata=HexBytes(calldata),
)
expected = f"""
CALL: {address}.<{calldata[:10]}>
└── CREATE2: 0x7c23b43594428A657718713FF246C609EeDDfAFf
""".strip()
assert len(node.calls) == 1
assert repr(node) == expected.strip()

expected_value = 123
expected_calldata = HexBytes(
"0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
"bcf7fffd8b256ec51a36782a52d0c34f6474d95100000000000000000000000000000000000000000000000000"
"000000000000000000000000000000000000000000000000000000000000000000000000000003000000000000"
"000000000000000000000000000000000000000000000000000360206101356000396000516000556101166100"
"1f61000039610116610000f36003361161000c576100fe565b60003560e01c346101045763425ace5281186100"
"3957600436106101045760036040526001606052610077565b6350144002811861005c57602436106101045760"
"04356040526001606052610077565b6318b30cb781186100a1576044361061010457604060046040375b604051"
"15610104576040516060518082018281106101045790509050600055600160805260206080f35b63c7dd9abe81"
"186100bf576004361061010457600160405260206040f35b6327871d9981186100dd5760043610610104576001"
"60405260206040f35b63f9bd55cc81186100fc57600436106101045760005460405260206040f35b505b600060"
"00fd5b600080fda165767970657283000307000b00000000000000000000000000000000000000000000000000"
"00000000000003"
assert len(node.calls) == 2
actual = repr(node)[:120]
pattern = re.compile(
rf".*\s*CALL: {address}\."
rf"<{calldata[:10]}>\s*├── CREATE2: 0x[a-fA-F0-9]{{40}}[\s└─├\w:.<?>]*"
)
create_node = node.calls[0]
assert create_node.value == expected_value
assert create_node.calldata.startswith(expected_calldata)
assert pattern.match(actual), f"actual: {actual}, pattern: {str(pattern)}"


def test_create_trace_frames_from_geth_create2_struct_logs(
Expand All @@ -168,7 +152,13 @@ def test_create_trace_frames_from_geth_create2_struct_logs(
assert len(frames) == len(geth_create2_trace_frames)
assert frames != geth_create2_trace_frames

assert "CREATE2" in [f.op for f in frames]
create2_found = False
for frame in frames:
if frame.op == "CREATE2":
assert frame.address == HexBytes("0x7c23b43594428a657718713ff246c609eeddfaff")
if frame.op.startswith("CREATE"):
assert frame.address
address = frame.address.hex()
assert address.startswith("0x")
assert len(address) == 42
create2_found = create2_found or frame.op == "CREATE2"

assert create2_found

0 comments on commit 46f45f2

Please sign in to comment.