Skip to content

Commit edb87b5

Browse files
Aragaharamichprev
authored andcommitted
✨ Add selector clashes detection to proxy detector
1 parent 7ca2e08 commit edb87b5

File tree

4 files changed

+157
-33
lines changed

4 files changed

+157
-33
lines changed

tests/detectors_sources/proxy_contract.sol tests/detectors_sources/proxy_contract_selector_clashes.sol

+40-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ contract Slots {
77
bytes32 internal constant _OWNER_SLOT = 0x02016836a56b71f0d02689e69e326f4f4c1b9057164ef592671cf0d37c8040c0;
88
}
99

10-
abstract contract Legit_Proxy_1 {
10+
abstract contract AProxy {
1111
function _implementation() internal view virtual returns (address);
1212

1313
function _delegate(address implementation) internal virtual {
@@ -37,15 +37,19 @@ abstract contract Legit_Proxy_1 {
3737
function _beforeFallback() internal virtual {}
3838
}
3939

40-
contract Bug_Proxy_1 is Legit_Proxy_1, Slots {
40+
contract Proxy is AProxy, Slots {
4141
function _implementation() internal view override returns (address implementation_) {
4242
assembly {
4343
implementation_ := sload(_IMPLEMENTATION_SLOT)
4444
}
4545
}
46+
47+
function bug_clash(address addr) public {
48+
int a = 1;
49+
}
4650
}
4751

48-
abstract contract Legit_Proxy_2 {
52+
abstract contract AProxy2 {
4953
function _implementation() internal view virtual returns (address);
5054

5155
function _delegate(address implementation) internal virtual {
@@ -248,17 +252,48 @@ library Address {
248252
}
249253
}
250254

251-
contract Bug_Impl_1 is Slots {
255+
contract Proxy2 is AProxy, Slots {
256+
function _implementation() internal view override returns (address implementation_) {
257+
assembly {
258+
implementation_ := sload(_IMPLEMENTATION_SLOT)
259+
}
260+
}
261+
262+
function legit_clash(address addr) public {
263+
int a = 1;
264+
}
265+
}
266+
267+
contract Impl1 is Slots {
252268
function _setImplementation(address newImplementation) private {
253269
require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract");
254270
StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation;
255271
}
272+
273+
function bug_clash(address addr) public {
274+
_setImplementation(addr);
275+
}
256276
}
257277

258-
contract Legit_Impl_2 is Slots {
278+
contract Impl2 is Slots {
279+
function _setImplementation(address newImplementation) private {
280+
require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract");
281+
StorageSlot.getAddressSlot(_IMPLEMENTATION_SLOT).value = newImplementation;
282+
}
283+
284+
function legit_clash(address addr, uint a) public {
285+
_setImplementation(addr);
286+
}
287+
}
288+
289+
contract Impl3 is Slots {
259290
function _setImplementation(address newImplementation) private {
260291
require(Address.isContract(newImplementation), "ERC1967: new implementation is not a contract");
261292
StorageSlot.getAddressSlot(_OWNER_SLOT).value = newImplementation;
262293
}
294+
295+
function legit_clash(address addr) public {
296+
_setImplementation(addr);
297+
}
263298
}
264299

tests/test_detectors.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import logging
23
import shutil
34
import tempfile
45
from pathlib import Path, PurePath
@@ -104,20 +105,24 @@ def check_detected_in_functions(config, source_path):
104105
detections = detect(config, source_units)
105106

106107
detections_fns = [_get_function_def(det.result.ir_node) for det in detections]
107-
detections_fn_names = [fn.name for fn in detections_fns if fn is not None]
108+
detections_fn_names = [fn.canonical_name for fn in detections_fns if fn is not None]
108109

110+
detected_something = False
109111
for contract in source_file.contracts:
110112
for fn in contract.functions:
111113
if fn.name.startswith("legit_"):
112-
assert fn.name not in detections_fn_names
114+
assert fn.canonical_name not in detections_fn_names
113115
elif fn.name.startswith("bug_"):
114-
assert fn.name in detections_fn_names
116+
detected_something = True
117+
assert fn.canonical_name in detections_fn_names
115118

116119
for fn in source_file.functions:
117120
if fn.name.startswith("legit_"):
118-
assert fn.name not in detections_fn_names
121+
assert fn.canonical_name not in detections_fn_names
119122
elif fn.name.startswith("bug_"):
120-
assert fn.name in detections_fn_names
123+
detected_something = True
124+
assert fn.canonical_name in detections_fn_names
125+
assert detected_something
121126

122127

123128
def check_detected_in_contracts(config, source_path):
@@ -130,11 +135,14 @@ def check_detected_in_contracts(config, source_path):
130135
contract.name for contract in detections_contracts if contract is not None
131136
]
132137

138+
detected_something = False
133139
for cf in source_file.contracts:
134140
if cf.name.startswith("Legit_"):
135141
assert cf.name not in detections_contract_names
136142
elif cf.name.startswith("Bug_"):
143+
detected_something = True
137144
assert cf.name in detections_contract_names
145+
assert detected_something
138146

139147

140148
class TestNoReturnDetector:
@@ -214,16 +222,15 @@ def test_sources(self, config, tmp_path):
214222
test_file = "not_used.sol"
215223
test_source_path = tmp_path / test_file
216224
shutil.copyfile(SOURCES_PATH / test_file, test_source_path)
217-
check_detected_in_functions(config, test_source_path)
218225
check_detected_in_contracts(config, test_source_path)
219226

220227

221-
class TestProxyContract:
228+
class TestProxyContractSelectorClashes:
222229
@pytest.fixture
223230
def config(self, tmp_path) -> WokeConfig:
224231
config_dict = {
225232
"compiler": {"solc": {"include_paths": ["./node_modules"]}},
226-
"detectors": {"only": {"proxy-contract"}},
233+
"detectors": {"only": {"proxy-contract-selector-clashes"}},
227234
}
228235
return WokeConfig.fromdict(
229236
config_dict,
@@ -232,7 +239,16 @@ def config(self, tmp_path) -> WokeConfig:
232239
)
233240

234241
def test_sources(self, config, tmp_path):
235-
test_file = "proxy_contract.sol"
242+
test_file = "proxy_contract_selector_clashes.sol"
236243
test_source_path = tmp_path / test_file
237244
shutil.copyfile(SOURCES_PATH / test_file, test_source_path)
238-
check_detected_in_contracts(config, test_source_path)
245+
246+
_, source_units = compile_project(test_source_path, config)
247+
detections = detect(config, source_units)
248+
logging.error(detections)
249+
detections_fns = [_get_function_def(det.result.ir_node) for det in detections]
250+
detections_fn_names = [
251+
fn.canonical_name for fn in detections_fns if fn is not None
252+
]
253+
254+
assert "Proxy.bug_clash" in detections_fn_names

woke/analysis/detectors/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .overflow_calldata_tuple_reencoding_bug import (
1515
OverflowCalldataTupleReencodingBugDetector,
1616
)
17-
from .proxy_contract import ProxyContractDetector
17+
from .proxy_contract_selector_clashes import ProxyContractSelectorClashDetector
1818
from .reentrancy import ReentrancyDetector
1919
from .unchecked_return_value import UncheckedFunctionReturnValueDetector
2020
from .unsafe_delegatecall import UnsafeDelegatecallDetector

woke/analysis/detectors/proxy_contract.py woke/analysis/detectors/proxy_contract_selector_clashes.py

+90-17
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from functools import lru_cache
2-
from typing import List, Optional, Set, Tuple, Union
2+
from typing import Dict, List, Optional, Set, Tuple, Union
33

44
from woke.analysis.detectors import DetectorAbc, DetectorResult, detector
55
from woke.analysis.detectors.utils import (
66
get_function_implementations,
77
pair_function_call_arguments,
88
)
9-
from woke.ast.enums import FunctionKind, GlobalSymbolsEnum, LiteralKind
9+
from woke.ast.enums import FunctionKind, GlobalSymbolsEnum, LiteralKind, Visibility
1010
from woke.ast.ir.abc import IrAbc
1111
from woke.ast.ir.declaration.abc import DeclarationAbc
1212
from woke.ast.ir.declaration.contract_definition import ContractDefinition
@@ -398,25 +398,70 @@ def detect_slot_usage(fn: FunctionDefinition, visited=None) -> List[DetectorResu
398398
return dets
399399

400400

401-
@detector(-1030, "proxy-contract")
402-
class ProxyContractDetector(DetectorAbc):
401+
def detect_selector_clashes(
402+
proxy_contract: ContractDefinition,
403+
impl_contract: ContractDefinition,
404+
proxy_detection: DetectorResult,
405+
impl_detection: DetectorResult,
406+
) -> List[DetectorResult]:
407+
fn_whitelist = ["implementation"]
408+
409+
proxy_selectors = {}
410+
for c in proxy_contract.linearized_base_contracts:
411+
for f in c.functions + c.declared_variables:
412+
if f.function_selector is not None:
413+
proxy_selectors[f.function_selector] = f
414+
415+
impl_selectors = {}
416+
for c in impl_contract.linearized_base_contracts:
417+
for f in c.functions + c.declared_variables:
418+
if f.function_selector is not None:
419+
impl_selectors[f.function_selector] = f
420+
421+
clashes = []
422+
for proxy_sel, proxy_fn in proxy_selectors.items():
423+
if isinstance(proxy_fn, FunctionDefinition) and proxy_fn.name in fn_whitelist:
424+
continue
425+
if proxy_sel in impl_selectors:
426+
clashes.append(
427+
DetectorResult(
428+
proxy_fn,
429+
"Detected selector clash with implementation contract",
430+
related_info=(
431+
DetectorResult(
432+
impl_selectors[proxy_sel],
433+
"Implementation function with same selector",
434+
related_info=(impl_detection,),
435+
),
436+
proxy_detection,
437+
),
438+
)
439+
)
440+
return clashes
441+
442+
443+
@detector(-1030, "proxy-contract-selector-clashes")
444+
class ProxyContractSelectorClashDetector(DetectorAbc):
403445
"""
404-
Detects proxy contracts based on fallback function and usage of slot variables and
446+
Detects selector clashes in proxy and implementation contracts.
447+
Proxy contracts are detected based on fallback function and usage of slot variables and
405448
implementation contracts that use same slots as proxy contracts
406449
"""
407450

408451
_proxy_detections: Set[Tuple[ContractDefinition, DetectorResult]]
409452
_proxy_associated_contracts: Set[ContractDefinition]
410453
_implementation_slots_detections: Set[Tuple[ContractDefinition, DetectorResult]]
411-
_implementation_slots: Set[VariableDeclaration]
454+
_implementation_slots: Dict[VariableDeclaration, List[ContractDefinition]]
412455

413456
def __init__(self):
414457
self._proxy_detections: Set[Tuple[ContractDefinition, DetectorResult]] = set()
415458
self._proxy_associated_contracts: Set[ContractDefinition] = set()
416459
self._implementation_slots_detections: Set[
417460
Tuple[ContractDefinition, DetectorResult]
418461
] = set()
419-
self._implementation_slots: Set[VariableDeclaration] = set()
462+
self._implementation_slots: Dict[
463+
VariableDeclaration, List[ContractDefinition]
464+
] = {}
420465

421466
def report(self) -> List[DetectorResult]:
422467
detections = []
@@ -428,9 +473,10 @@ def report(self) -> List[DetectorResult]:
428473
continue
429474
base_proxy_contracts.add(c)
430475

476+
proxy_detections = {}
431477
for (contract, det) in self._proxy_detections:
432478
if contract not in base_proxy_contracts:
433-
detections.append(det)
479+
proxy_detections[contract] = det
434480

435481
base_impl_contracts = set()
436482
for (contract, _) in self._implementation_slots_detections:
@@ -439,22 +485,37 @@ def report(self) -> List[DetectorResult]:
439485
continue
440486
base_impl_contracts.add(c)
441487

488+
checked_pairs = set()
442489
impl_detections_contracts = set()
443490
for (contract, det) in self._implementation_slots_detections:
491+
impl_slot = get_last_detection_node(det).parent
444492
if (
445493
contract not in self._proxy_associated_contracts
446494
and contract not in base_impl_contracts
447-
and get_last_detection_node(det).parent in self._implementation_slots
495+
and impl_slot in self._implementation_slots
448496
and contract not in impl_detections_contracts
449497
):
450-
impl_detections_contracts.add(contract)
451-
detections.append(
452-
DetectorResult(
498+
for proxy_contract in self._implementation_slots[impl_slot]:
499+
if (
500+
proxy_contract,
501+
contract,
502+
) in checked_pairs or proxy_contract not in proxy_detections:
503+
continue
504+
checked_pairs.add((proxy_contract, contract))
505+
impl_det = DetectorResult(
453506
contract,
454-
"Detected implementation contract with slot used in proxy contract",
507+
f"Detected implementation contract with slot used in proxy contract {proxy_contract.name}",
455508
related_info=(det,),
456509
)
457-
)
510+
511+
dets = detect_selector_clashes(
512+
proxy_contract,
513+
contract,
514+
proxy_detections[proxy_contract],
515+
impl_det,
516+
)
517+
if len(dets) > 0:
518+
detections.extend(dets)
458519
return list(detections)
459520

460521
def visit_contract_definition(self, node: ContractDefinition):
@@ -478,10 +539,22 @@ def visit_contract_definition(self, node: ContractDefinition):
478539
)
479540
for det in dets:
480541
last_det_node = get_last_detection_node(det)
481-
if isinstance(
482-
last_det_node.parent, VariableDeclaration
542+
if (
543+
isinstance(
544+
last_det_node.parent, VariableDeclaration
545+
)
546+
and last_det_node.parent is not None
483547
):
484-
self._implementation_slots.add(last_det_node.parent)
548+
if (
549+
last_det_node.parent
550+
not in self._implementation_slots
551+
):
552+
self._implementation_slots[
553+
last_det_node.parent
554+
] = []
555+
self._implementation_slots[
556+
last_det_node.parent
557+
].append(node)
485558
self._proxy_associated_contracts.add(node)
486559
for bc in node.linearized_base_contracts:
487560
self._proxy_associated_contracts.add(bc)

0 commit comments

Comments
 (0)