1
1
from functools import lru_cache
2
- from typing import List , Optional , Set , Tuple , Union
2
+ from typing import Dict , List , Optional , Set , Tuple , Union
3
3
4
4
from woke .analysis .detectors import DetectorAbc , DetectorResult , detector
5
5
from woke .analysis .detectors .utils import (
6
6
get_function_implementations ,
7
7
pair_function_call_arguments ,
8
8
)
9
- from woke .ast .enums import FunctionKind , GlobalSymbolsEnum , LiteralKind
9
+ from woke .ast .enums import FunctionKind , GlobalSymbolsEnum , LiteralKind , Visibility
10
10
from woke .ast .ir .abc import IrAbc
11
11
from woke .ast .ir .declaration .abc import DeclarationAbc
12
12
from woke .ast .ir .declaration .contract_definition import ContractDefinition
@@ -398,25 +398,70 @@ def detect_slot_usage(fn: FunctionDefinition, visited=None) -> List[DetectorResu
398
398
return dets
399
399
400
400
401
- @detector (- 1030 , "proxy-contract" )
402
- class ProxyContractDetector (DetectorAbc ):
401
+ def detect_selector_clashes (
402
+ proxy_contract : ContractDefinition ,
403
+ impl_contract : ContractDefinition ,
404
+ proxy_detection : DetectorResult ,
405
+ impl_detection : DetectorResult ,
406
+ ) -> List [DetectorResult ]:
407
+ fn_whitelist = ["implementation" ]
408
+
409
+ proxy_selectors = {}
410
+ for c in proxy_contract .linearized_base_contracts :
411
+ for f in c .functions + c .declared_variables :
412
+ if f .function_selector is not None :
413
+ proxy_selectors [f .function_selector ] = f
414
+
415
+ impl_selectors = {}
416
+ for c in impl_contract .linearized_base_contracts :
417
+ for f in c .functions + c .declared_variables :
418
+ if f .function_selector is not None :
419
+ impl_selectors [f .function_selector ] = f
420
+
421
+ clashes = []
422
+ for proxy_sel , proxy_fn in proxy_selectors .items ():
423
+ if isinstance (proxy_fn , FunctionDefinition ) and proxy_fn .name in fn_whitelist :
424
+ continue
425
+ if proxy_sel in impl_selectors :
426
+ clashes .append (
427
+ DetectorResult (
428
+ proxy_fn ,
429
+ "Detected selector clash with implementation contract" ,
430
+ related_info = (
431
+ DetectorResult (
432
+ impl_selectors [proxy_sel ],
433
+ "Implementation function with same selector" ,
434
+ related_info = (impl_detection ,),
435
+ ),
436
+ proxy_detection ,
437
+ ),
438
+ )
439
+ )
440
+ return clashes
441
+
442
+
443
+ @detector (- 1030 , "proxy-contract-selector-clashes" )
444
+ class ProxyContractSelectorClashDetector (DetectorAbc ):
403
445
"""
404
- Detects proxy contracts based on fallback function and usage of slot variables and
446
+ Detects selector clashes in proxy and implementation contracts.
447
+ Proxy contracts are detected based on fallback function and usage of slot variables and
405
448
implementation contracts that use same slots as proxy contracts
406
449
"""
407
450
408
451
_proxy_detections : Set [Tuple [ContractDefinition , DetectorResult ]]
409
452
_proxy_associated_contracts : Set [ContractDefinition ]
410
453
_implementation_slots_detections : Set [Tuple [ContractDefinition , DetectorResult ]]
411
- _implementation_slots : Set [VariableDeclaration ]
454
+ _implementation_slots : Dict [VariableDeclaration , List [ ContractDefinition ] ]
412
455
413
456
def __init__ (self ):
414
457
self ._proxy_detections : Set [Tuple [ContractDefinition , DetectorResult ]] = set ()
415
458
self ._proxy_associated_contracts : Set [ContractDefinition ] = set ()
416
459
self ._implementation_slots_detections : Set [
417
460
Tuple [ContractDefinition , DetectorResult ]
418
461
] = set ()
419
- self ._implementation_slots : Set [VariableDeclaration ] = set ()
462
+ self ._implementation_slots : Dict [
463
+ VariableDeclaration , List [ContractDefinition ]
464
+ ] = {}
420
465
421
466
def report (self ) -> List [DetectorResult ]:
422
467
detections = []
@@ -428,9 +473,10 @@ def report(self) -> List[DetectorResult]:
428
473
continue
429
474
base_proxy_contracts .add (c )
430
475
476
+ proxy_detections = {}
431
477
for (contract , det ) in self ._proxy_detections :
432
478
if contract not in base_proxy_contracts :
433
- detections . append ( det )
479
+ proxy_detections [ contract ] = det
434
480
435
481
base_impl_contracts = set ()
436
482
for (contract , _ ) in self ._implementation_slots_detections :
@@ -439,22 +485,37 @@ def report(self) -> List[DetectorResult]:
439
485
continue
440
486
base_impl_contracts .add (c )
441
487
488
+ checked_pairs = set ()
442
489
impl_detections_contracts = set ()
443
490
for (contract , det ) in self ._implementation_slots_detections :
491
+ impl_slot = get_last_detection_node (det ).parent
444
492
if (
445
493
contract not in self ._proxy_associated_contracts
446
494
and contract not in base_impl_contracts
447
- and get_last_detection_node ( det ). parent in self ._implementation_slots
495
+ and impl_slot in self ._implementation_slots
448
496
and contract not in impl_detections_contracts
449
497
):
450
- impl_detections_contracts .add (contract )
451
- detections .append (
452
- DetectorResult (
498
+ for proxy_contract in self ._implementation_slots [impl_slot ]:
499
+ if (
500
+ proxy_contract ,
501
+ contract ,
502
+ ) in checked_pairs or proxy_contract not in proxy_detections :
503
+ continue
504
+ checked_pairs .add ((proxy_contract , contract ))
505
+ impl_det = DetectorResult (
453
506
contract ,
454
- "Detected implementation contract with slot used in proxy contract" ,
507
+ f "Detected implementation contract with slot used in proxy contract { proxy_contract . name } " ,
455
508
related_info = (det ,),
456
509
)
457
- )
510
+
511
+ dets = detect_selector_clashes (
512
+ proxy_contract ,
513
+ contract ,
514
+ proxy_detections [proxy_contract ],
515
+ impl_det ,
516
+ )
517
+ if len (dets ) > 0 :
518
+ detections .extend (dets )
458
519
return list (detections )
459
520
460
521
def visit_contract_definition (self , node : ContractDefinition ):
@@ -478,10 +539,22 @@ def visit_contract_definition(self, node: ContractDefinition):
478
539
)
479
540
for det in dets :
480
541
last_det_node = get_last_detection_node (det )
481
- if isinstance (
482
- last_det_node .parent , VariableDeclaration
542
+ if (
543
+ isinstance (
544
+ last_det_node .parent , VariableDeclaration
545
+ )
546
+ and last_det_node .parent is not None
483
547
):
484
- self ._implementation_slots .add (last_det_node .parent )
548
+ if (
549
+ last_det_node .parent
550
+ not in self ._implementation_slots
551
+ ):
552
+ self ._implementation_slots [
553
+ last_det_node .parent
554
+ ] = []
555
+ self ._implementation_slots [
556
+ last_det_node .parent
557
+ ].append (node )
485
558
self ._proxy_associated_contracts .add (node )
486
559
for bc in node .linearized_base_contracts :
487
560
self ._proxy_associated_contracts .add (bc )
0 commit comments