1
- import json
2
1
import importlib
2
+ import json
3
3
from datetime import datetime , timedelta
4
4
from decimal import Decimal
5
5
from functools import partial
6
- from typing import Any , Dict , Iterable , Iterator , List , Optional , Union , cast , ClassVar
6
+ from pathlib import Path
7
+ from typing import Any , ClassVar , Dict , Iterable , Iterator , List , Optional , Union , cast
7
8
8
9
from ape .api import ReceiptAPI
9
10
from ape .contracts .base import ContractInstance , ContractTransactionHandler
10
11
from ape .exceptions import (
11
12
CompilerError ,
12
13
ContractLogicError ,
14
+ ContractNotFoundError ,
13
15
DecodingError ,
14
16
ProjectError ,
15
- ContractNotFoundError ,
16
17
)
17
- from ape .types import AddressType , ContractLog , HexBytes
18
+ from ape .types import AddressType , ContractLog
18
19
from ape .utils import BaseInterfaceModel , cached_property
19
20
from ethpm_types import ContractType , PackageManifest
20
- from pydantic import ValidationError , validator
21
+ from hexbytes import HexBytes
22
+ from pydantic import ValidationError , ValidationInfo , field_validator
21
23
22
24
from .exceptions import (
23
25
FundsNotClaimable ,
@@ -48,7 +50,7 @@ def __eq__(self, other: Any) -> bool:
48
50
# Try __eq__ from the other side.
49
51
return NotImplemented
50
52
51
- def validate (self , creator , token , amount_per_second , reason ) -> bool :
53
+ def validate (self , creator , token , amount_per_second , reason ) -> bool : # type: ignore
52
54
try :
53
55
self .contract .validate .call (creator , token , amount_per_second , reason )
54
56
return True
@@ -63,14 +65,16 @@ def validate(self, creator, token, amount_per_second, reason) -> bool:
63
65
class StreamManager (BaseInterfaceModel ):
64
66
address : AddressType
65
67
contract_type : Optional [ContractType ] = None
66
- _local_contracts : ClassVar [Dict [str , ContractType ]]
68
+ _local_contracts : ClassVar [Dict [str , ContractType ]] = dict ()
67
69
68
- @validator ("address" , pre = True )
70
+ @field_validator ("address" , mode = "before" )
71
+ @classmethod
69
72
def normalize_address (cls , value : Any ) -> AddressType :
70
73
return cls .conversion_manager .convert (value , AddressType )
71
74
72
- @validator ("contract_type" , pre = True , always = True )
73
- def fetch_contract_type (cls , value : Any , values : Dict [str , Any ]) -> ContractType :
75
+ @field_validator ("contract_type" , mode = "before" )
76
+ @classmethod
77
+ def fetch_contract_type (cls , value : Any , info : ValidationInfo ) -> Optional [ContractType ]:
74
78
# 0. If pre-loaded, default to that type
75
79
if value :
76
80
return value
@@ -86,19 +90,24 @@ def fetch_contract_type(cls, value: Any, values: Dict[str, Any]) -> ContractType
86
90
87
91
# 2. If contract cache has it, use that
88
92
try :
89
- if values .get ("address" ) and (
90
- contract_type := cls .chain_manager .contracts .get (values ["address" ])
93
+ if info . data .get ("address" ) and (
94
+ contract_type := cls .chain_manager .contracts .get (info . data ["address" ])
91
95
):
92
96
return contract_type
93
97
94
98
except Exception :
95
99
pass
96
100
97
101
# 3. Most expensive way is through package resources
98
- cls ._local_contracts = PackageManifest .parse_file (
99
- importlib .resources .files ("apepay" ) / "manifest.json"
100
- ).contract_types
101
- return cls ._local_contracts ["StreamManager" ]
102
+ manifest_file = Path (__file__ ).parent / "manifest.json"
103
+ manifest_text = manifest_file .read_text ()
104
+ manifest = PackageManifest .parse_raw (manifest_text )
105
+
106
+ if not manifest or not manifest .contract_types :
107
+ raise ValueError ("Invalid manifest" )
108
+
109
+ cls ._local_contracts = manifest .contract_types
110
+ return cls ._local_contracts .get ("StreamManager" )
102
111
103
112
@property
104
113
def contract (self ) -> ContractInstance :
@@ -157,7 +166,7 @@ def set_validators(
157
166
158
167
def add_validators (
159
168
self ,
160
- * new_validators : Iterable [ _ValidatorItem ] ,
169
+ * new_validators : _ValidatorItem ,
161
170
** txn_kwargs ,
162
171
) -> ReceiptAPI :
163
172
return self .set_validators (
@@ -167,7 +176,7 @@ def add_validators(
167
176
168
177
def remove_validators (
169
178
self ,
170
- * validators : Iterable [ _ValidatorItem ] ,
179
+ * validators : _ValidatorItem ,
171
180
** txn_kwargs ,
172
181
) -> ReceiptAPI :
173
182
return self .set_validators (
@@ -307,14 +316,16 @@ class Stream(BaseInterfaceModel):
307
316
creation_receipt : Optional [ReceiptAPI ] = None
308
317
transaction_hash : Optional [HexBytes ] = None
309
318
310
- @validator ("transaction_hash" , pre = True )
319
+ @field_validator ("transaction_hash" , mode = "before" )
320
+ @classmethod
311
321
def normalize_transaction_hash (cls , value : Any ) -> Optional [HexBytes ]:
312
322
if value :
313
323
return HexBytes (cls .conversion_manager .convert (value , bytes ))
314
324
315
325
return value
316
326
317
- @validator ("creator" , pre = True )
327
+ @field_validator ("creator" , mode = "before" )
328
+ @classmethod
318
329
def validate_addresses (cls , value ):
319
330
return (
320
331
value if isinstance (value , str ) else cls .conversion_manager .convert (value , AddressType )
0 commit comments