From 3d0030c947be6a72aa5f51c3997a77b7d2cd0657 Mon Sep 17 00:00:00 2001 From: slush Date: Mon, 14 Oct 2024 20:38:27 -0500 Subject: [PATCH 1/4] refactor!:replaces all string types with appropriate validated types from eth_pydantic_types library (cherry picked from commit 0df69cda8efc656f2fb792ef06dffa938c5f20a2) --- eip712/common.py | 90 ++++++++++++++++--------------- eip712/messages.py | 122 ++++++++++++++++++++++++++---------------- tests/conftest.py | 31 +++++------ tests/test_fuzzing.py | 2 +- 4 files changed, 140 insertions(+), 105 deletions(-) diff --git a/eip712/common.py b/eip712/common.py index 5c8dd61..e971ea7 100644 --- a/eip712/common.py +++ b/eip712/common.py @@ -2,46 +2,52 @@ # Collection of commonly-used EIP712 message type definitions from typing import Optional, Type, Union +from eth_pydantic_types.abi import address, bytes, bytes32, string, uint8, uint256 + from .messages import EIP712Message class EIP2612(EIP712Message): # NOTE: Subclass this w/ at least one header field - owner: "address" # type: ignore - spender: "address" # type: ignore - value: "uint256" # type: ignore - nonce: "uint256" # type: ignore - deadline: "uint256" # type: ignore + owner: address + spender: address + value: uint256 + nonce: uint256 + deadline: uint256 class EIP4494(EIP712Message): # NOTE: Subclass this w/ at least one header field - spender: "address" # type: ignore - tokenId: "uint256" # type: ignore - nonce: "uint256" # type: ignore - deadline: "uint256" # type: ignore + spender: address + tokenId: uint256 + nonce: uint256 + deadline: uint256 def create_permit_def(eip=2612, **header_fields): if eip == 2612: class Permit(EIP2612): - _name_ = header_fields.get("name", None) - _version_ = header_fields.get("version", None) - _chainId_ = header_fields.get("chainId", None) - _verifyingContract_ = header_fields.get("verifyingContract", None) - _salt_ = header_fields.get("salt", None) + eip712_name_: Optional[string] = header_fields.get("name", None) + eip712_version_: Optional[string] = header_fields.get("version", None) + eip712_chainId_: Optional[uint256] = header_fields.get("chainId", None) + eip712_verifyingContract_: Optional[string] = header_fields.get( + "verifyingContract", None + ) + eip712_salt_: Optional[bytes32] = header_fields.get("salt", None) elif eip == 4494: class Permit(EIP4494): - _name_ = header_fields.get("name", None) - _version_ = header_fields.get("version", None) - _chainId_ = header_fields.get("chainId", None) - _verifyingContract_ = header_fields.get("verifyingContract", None) - _salt_ = header_fields.get("salt", None) + eip712_name_: Optional[string] = header_fields.get("name", None) + eip712_version_: Optional[string] = header_fields.get("version", None) + eip712_chainId_: Optional[uint256] = header_fields.get("chainId", None) + eip712_verifyingContract_: Optional[string] = header_fields.get( + "verifyingContract", None + ) + eip712_salt_: Optional[bytes32] = header_fields.get("salt", None) else: raise ValueError(f"Invalid eip {eip}, must use one of: {EIP2612}, {EIP4494}") @@ -51,30 +57,30 @@ class Permit(EIP4494): class SafeTxV1(EIP712Message): # NOTE: Subclass this as `SafeTx` w/ at least one header field - to: "address" # type: ignore - value: "uint256" = 0 # type: ignore - data: "bytes" = b"" - operation: "uint8" = 0 # type: ignore - safeTxGas: "uint256" = 0 # type: ignore - dataGas: "uint256" = 0 # type: ignore - gasPrice: "uint256" = 0 # type: ignore - gasToken: "address" = "0x0000000000000000000000000000000000000000" # type: ignore - refundReceiver: "address" = "0x0000000000000000000000000000000000000000" # type: ignore - nonce: "uint256" # type: ignore + to: address + value: uint256 = 0 + data: bytes = b"" + operation: uint8 = 0 + safeTxGas: uint256 = 0 + dataGas: uint256 = 0 + gasPrice: uint256 = 0 + gasToken: address = "0x0000000000000000000000000000000000000000" + refundReceiver: address = "0x0000000000000000000000000000000000000000" + nonce: uint256 class SafeTxV2(EIP712Message): # NOTE: Subclass this as `SafeTx` w/ at least one header field - to: "address" # type: ignore - value: "uint256" = 0 # type: ignore - data: "bytes" = b"" - operation: "uint8" = 0 # type: ignore - safeTxGas: "uint256" = 0 # type: ignore - baseGas: "uint256" = 0 # type: ignore - gasPrice: "uint256" = 0 # type: ignore - gasToken: "address" = "0x0000000000000000000000000000000000000000" # type: ignore - refundReceiver: "address" = "0x0000000000000000000000000000000000000000" # type: ignore - nonce: "uint256" # type: ignore + to: address + value: uint256 = 0 + data: bytes = b"" + operation: uint8 = 0 + safeTxGas: uint256 = 0 + baseGas: uint256 = 0 + gasPrice: uint256 = 0 + gasToken: address = "0x0000000000000000000000000000000000000000" + refundReceiver: address = "0x0000000000000000000000000000000000000000" + nonce: uint256 SafeTx = Union[SafeTxV1, SafeTxV2] @@ -97,7 +103,7 @@ def create_safe_tx_def( if minor < 3: class SafeTx(SafeTxV1): - _verifyingContract_ = contract_address + eip712_verifyingContract_: address = contract_address elif not chain_id: raise ValueError("Must supply 'chain_id=' for Safe versions 1.3.0 or later") @@ -105,7 +111,7 @@ class SafeTx(SafeTxV1): else: class SafeTx(SafeTxV2): # type: ignore[no-redef] - _chainId_ = chain_id - _verifyingContract_ = contract_address + eip712_chainId_: uint256 = chain_id + eip712_verifyingContract_: address = contract_address return SafeTx diff --git a/eip712/messages.py b/eip712/messages.py index 9739fe2..76f3eb4 100644 --- a/eip712/messages.py +++ b/eip712/messages.py @@ -4,12 +4,14 @@ from typing import Any, Optional -from dataclassy import asdict, dataclass, fields from eth_abi.abi import is_encodable_type # type: ignore[import-untyped] from eth_account.messages import SignableMessage, hash_domain, hash_eip712_message +from eth_pydantic_types import Address, HexBytes +from eth_pydantic_types.abi import bytes32, string, uint256 from eth_utils import keccak from eth_utils.curried import ValidationError -from hexbytes import HexBytes +from pydantic import BaseModel, model_validator +from typing_extensions import _AnnotatedAlias # ! Do not change the order of the fields in this list ! # To correctly encode and hash the domain fields, they @@ -30,8 +32,7 @@ ] -@dataclass(iter=True, slots=True, kwargs=True, kw_only=True) -class EIP712Type: +class EIP712Type(BaseModel): """ Dataclass for `EIP-712 `__ structured data types (i.e. the contents of an :class:`EIP712Message`). @@ -48,38 +49,65 @@ def _types_(self) -> dict: """ types: dict[str, list] = {repr(self): []} - for field in fields(self.__class__): + for field in { + k: v.annotation.__name__ + for k, v in self.model_fields.items() + if not k.startswith("eip712_") + }: value = getattr(self, field) if isinstance(value, EIP712Type): types[repr(self)].append({"name": field, "type": repr(value)}) types.update(value._types_) else: - # TODO: Use proper ABI typing, not strings - field_type = self.__annotations__[field] + field_type = search_annotations(self, field) + # If the field type is a string, validate through eth-abi if isinstance(field_type, str): if not is_encodable_type(field_type): - raise ValidationError(f"'{field}: {field_type}' is not a valid ABI type") + raise ValidationError(f"'{field}: {field_type}' is not a valid ABI Type") - elif issubclass(field_type, EIP712Type): + elif isinstance(field_type, type) and issubclass(field_type, EIP712Type): field_type = repr(field_type) else: - raise ValidationError( - f"'{field}' type annotation must either be a subclass of " - f"`EIP712Type` or valid ABI Type string, not {field_type.__name__}" - ) + try: + # If field type already has validators or is a known type + # can confirm that type name will be correct + if isinstance(field_type.__value__, _AnnotatedAlias) or issubclass( + field_type.__value__, (Address, HexBytes) + ): + field_type = field_type.__name__ + + except AttributeError: + raise ValidationError( + f"'{field}' type annotation must either be a subclass of " + f"`EIP712Type` or valid ABI Type, not {field_type.__name__}" + ) types[repr(self)].append({"name": field, "type": field_type}) return types def __getitem__(self, key: str) -> Any: - if (key.startswith("_") and key.endswith("_")) or key not in fields(self.__class__): + if (key.startswith("_") and key.endswith("_")) or key not in self.model_fields: raise KeyError("Cannot look up header fields or other attributes this way") return getattr(self, key) + def _prepare_data_for_hashing(self, data: dict) -> dict: + result: dict = {} + + for key, value in data.items(): + item: Any = value + if isinstance(value, EIP712Type): + item = value.model_dump(mode="json") + elif isinstance(value, dict): + item = self._prepare_data_for_hashing(item) + + result[key] = item + + return result + class EIP712Message(EIP712Type): """ @@ -88,19 +116,22 @@ class EIP712Message(EIP712Type): """ # NOTE: Must override at least one of these fields - _name_: Optional[str] = None - _version_: Optional[str] = None - _chainId_: Optional[int] = None - _verifyingContract_: Optional[str] = None - _salt_: Optional[bytes] = None - - def __post_init__(self): + eip712_name_: Optional[string] = None + eip712_version_: Optional[string] = None + eip712_chainId_: Optional[uint256] = None + eip712_verifyingContract_: Optional[string] = None + eip712_salt_: Optional[bytes32] = None + + @model_validator(mode="after") + @classmethod + def validate_model(cls, value): # At least one of the header fields must be in the EIP712 message header - if not any(getattr(self, f"_{field}_") for field in EIP712_DOMAIN_FIELDS): + if not any(f"eip712_{field}_" in value.__annotations__ for field in EIP712_DOMAIN_FIELDS): raise ValidationError( - f"EIP712 Message definition '{repr(self)}' must define " - f"at least one of: _{'_, _'.join(EIP712_DOMAIN_FIELDS)}_" + f"EIP712 Message definition '{repr(cls)}' must define " + f"at least one of: eip712_{'_, eip712_'.join(EIP712_DOMAIN_FIELDS)}_" ) + return value @property def _domain_(self) -> dict: @@ -108,13 +139,15 @@ def _domain_(self) -> dict: domain_type = [ {"name": field, "type": abi_type} for field, abi_type in EIP712_DOMAIN_FIELDS.items() - if getattr(self, f"_{field}_") + if getattr(self, f"eip712_{field}_") ] return { "types": { "EIP712Domain": domain_type, }, - "domain": {field["name"]: getattr(self, f"_{field['name']}_") for field in domain_type}, + "domain": { + field["name"]: getattr(self, f"eip712_{field['name']}_") for field in domain_type + }, } @property @@ -126,9 +159,10 @@ def _body_(self) -> dict: "types": dict(self._types_, **self._domain_["types"]), "primaryType": repr(self), "message": { + # TODO use __pydantic_extra__ instead key: getattr(self, key) - for key in fields(self.__class__) - if not key.startswith("_") or not key.endswith("_") + for key in self.model_fields + if not key.startswith("eip712_") or not key.endswith("_") }, } @@ -144,13 +178,16 @@ def signable_message(self) -> SignableMessage: The current message as a :class:`SignableMessage` named tuple instance. **NOTE**: The 0x19 prefix is NOT included. """ - domain = _prepare_data_for_hashing(self._domain_["domain"]) - types = _prepare_data_for_hashing(self._types_) - message = _prepare_data_for_hashing(self._body_["message"]) + domain = self._prepare_data_for_hashing(self._domain_["domain"]) + types = self._prepare_data_for_hashing(self._types_) + message = self._prepare_data_for_hashing(self._body_["message"]) + messagebytes = HexBytes(1) + messageDomain = HexBytes(hash_domain(domain)) + messageEIP = HexBytes(hash_eip712_message(types, message)) return SignableMessage( - HexBytes(1), - HexBytes(hash_domain(domain)), - HexBytes(hash_eip712_message(types, message)), + messagebytes, + messageDomain, + messageEIP, ) @@ -158,16 +195,7 @@ def calculate_hash(msg: SignableMessage) -> HexBytes: return HexBytes(keccak(b"".join([bytes.fromhex("19"), *msg]))) -def _prepare_data_for_hashing(data: dict) -> dict: - result: dict = {} - - for key, value in data.items(): - item: Any = value - if isinstance(value, EIP712Type): - item = asdict(value) - elif isinstance(value, dict): - item = _prepare_data_for_hashing(item) - - result[key] = item - - return result +def search_annotations(cls, field: str) -> Any: + if hasattr(cls, "__annotations__") and field in cls.__annotations__: + return cls.__annotations__[field] + return search_annotations(super(cls.__class__, cls), field) diff --git a/tests/conftest.py b/tests/conftest.py index b1a6996..edd1158 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import pytest +from eth_pydantic_types.abi import address, bytes32, string, uint256 from hexbytes import HexBytes from eip712.common import create_permit_def @@ -17,34 +18,34 @@ class SubType(EIP712Type): - inner: "uint256" # type: ignore + inner: uint256 class ValidMessageWithNameDomainField(EIP712Message): - _name_ = "Valid Test Message" - value: "uint256" # type: ignore - default_value: "address" = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" # type: ignore + eip712_name_: string = "Valid Test Message" + value: uint256 + default_value: address = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" # type: ignore sub: SubType class MessageWithNonCanonicalDomainFieldOrder(EIP712Message): - _name_ = PERMIT_NAME - _salt_ = PERMIT_SALT - _chainId_ = PERMIT_CHAIN_ID - _version_ = PERMIT_VERSION - _verifyingContract_ = PERMIT_VAULT_ADDRESS + eip712_name_: string = PERMIT_NAME + eip712_salt_: bytes32 = PERMIT_SALT + eip712_chainId_: uint256 = PERMIT_CHAIN_ID + eip712_version_: string = PERMIT_VERSION + eip712_verifyingContract_: address = PERMIT_VAULT_ADDRESS class MessageWithCanonicalDomainFieldOrder(EIP712Message): - _name_ = PERMIT_NAME - _version_ = PERMIT_VERSION - _chainId_ = PERMIT_CHAIN_ID - _verifyingContract_ = PERMIT_VAULT_ADDRESS - _salt_ = PERMIT_SALT + eip712_name_: string = PERMIT_NAME + eip712_version_: string = PERMIT_VERSION + eip712_chainId_: uint256 = PERMIT_CHAIN_ID + eip712_verifyingContract_: address = PERMIT_VAULT_ADDRESS + eip712_salt_: bytes32 = PERMIT_SALT class InvalidMessageMissingDomainFields(EIP712Message): - value: "uint256" # type: ignore + value: uint256 @pytest.fixture diff --git a/tests/test_fuzzing.py b/tests/test_fuzzing.py index 4d2cfef..d109ee0 100644 --- a/tests/test_fuzzing.py +++ b/tests/test_fuzzing.py @@ -26,7 +26,7 @@ def test_random_message_def(types, data): exec( f"""class Msg(EIP712Message): - _name_="test def" + eip712_name_:str="test def" {members_str}""", globals(), ) # Creates `Msg` definition From 31392b59be1dfdaab53ae6dab7d09e6257700062 Mon Sep 17 00:00:00 2001 From: slush Date: Wed, 11 Dec 2024 10:49:45 -0600 Subject: [PATCH 2/4] feat: adds protected prefix to __getitem__ check --- eip712/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eip712/messages.py b/eip712/messages.py index 76f3eb4..fb3aca1 100644 --- a/eip712/messages.py +++ b/eip712/messages.py @@ -89,7 +89,7 @@ def _types_(self) -> dict: return types def __getitem__(self, key: str) -> Any: - if (key.startswith("_") and key.endswith("_")) or key not in self.model_fields: + if (key.startswith("eip712_") and key.endswith("_")) or key not in self.model_fields: raise KeyError("Cannot look up header fields or other attributes this way") return getattr(self, key) From 356070b021f3b11b7768cf770186ed32a350bec7 Mon Sep 17 00:00:00 2001 From: slush Date: Wed, 18 Dec 2024 01:06:21 -0600 Subject: [PATCH 3/4] feat: update deps --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2a39fb3..c302520 100644 --- a/setup.py +++ b/setup.py @@ -69,9 +69,9 @@ url="https://github.com/ApeWorX/eip712", include_package_data=True, install_requires=[ - "dataclassy>=0.11.1,<1", "eth-abi>=5.1.0,<6", "eth-account>=0.11.3,<0.14", + "eth-pydantic-types>=0.2.0,<1", "eth-typing>=3.5.2,<6", "eth-utils>=2.3.1,<6", "hexbytes>=0.3.1,<2", From 2b08a2a613b4ce9a001d9cb3f680bd26b471be77 Mon Sep 17 00:00:00 2001 From: slush Date: Wed, 18 Dec 2024 01:06:45 -0600 Subject: [PATCH 4/4] feat: import perf improvements --- eip712/messages.py | 18 ++++++++++-------- tests/conftest.py | 36 ++++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/eip712/messages.py b/eip712/messages.py index fb3aca1..c9feff2 100644 --- a/eip712/messages.py +++ b/eip712/messages.py @@ -2,17 +2,19 @@ Message classes for typed structured data hashing and signing in Ethereum. """ -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from eth_abi.abi import is_encodable_type # type: ignore[import-untyped] from eth_account.messages import SignableMessage, hash_domain, hash_eip712_message from eth_pydantic_types import Address, HexBytes -from eth_pydantic_types.abi import bytes32, string, uint256 from eth_utils import keccak from eth_utils.curried import ValidationError from pydantic import BaseModel, model_validator from typing_extensions import _AnnotatedAlias +if TYPE_CHECKING: + from eth_pydantic_types.abi import bytes32, string, uint256 + # ! Do not change the order of the fields in this list ! # To correctly encode and hash the domain fields, they # must be in this precise order. @@ -50,7 +52,7 @@ def _types_(self) -> dict: types: dict[str, list] = {repr(self): []} for field in { - k: v.annotation.__name__ + k: v.annotation.__name__ # type: ignore[union-attr] for k, v in self.model_fields.items() if not k.startswith("eip712_") }: @@ -116,11 +118,11 @@ class EIP712Message(EIP712Type): """ # NOTE: Must override at least one of these fields - eip712_name_: Optional[string] = None - eip712_version_: Optional[string] = None - eip712_chainId_: Optional[uint256] = None - eip712_verifyingContract_: Optional[string] = None - eip712_salt_: Optional[bytes32] = None + eip712_name_: Optional["string"] = None + eip712_version_: Optional["string"] = None + eip712_chainId_: Optional["uint256"] = None + eip712_verifyingContract_: Optional["string"] = None + eip712_salt_: Optional["bytes32"] = None @model_validator(mode="after") @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index edd1158..c7b1701 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,14 @@ +from typing import TYPE_CHECKING + import pytest -from eth_pydantic_types.abi import address, bytes32, string, uint256 from hexbytes import HexBytes from eip712.common import create_permit_def from eip712.messages import EIP712Message, EIP712Type +if TYPE_CHECKING: + from eth_pydantic_types.abi import address, bytes32, string, uint256 + PERMIT_NAME = "Yearn Vault" PERMIT_VERSION = "0.3.5" PERMIT_CHAIN_ID = 1 @@ -18,34 +22,34 @@ class SubType(EIP712Type): - inner: uint256 + inner: "uint256" class ValidMessageWithNameDomainField(EIP712Message): - eip712_name_: string = "Valid Test Message" - value: uint256 - default_value: address = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" # type: ignore + eip712_name_: "string" = "Valid Test Message" + value: "uint256" + default_value: "address" = "0xFFfFfFffFFfffFFfFFfFFFFFffFFFffffFfFFFfF" sub: SubType class MessageWithNonCanonicalDomainFieldOrder(EIP712Message): - eip712_name_: string = PERMIT_NAME - eip712_salt_: bytes32 = PERMIT_SALT - eip712_chainId_: uint256 = PERMIT_CHAIN_ID - eip712_version_: string = PERMIT_VERSION - eip712_verifyingContract_: address = PERMIT_VAULT_ADDRESS + eip712_name_: "string" = PERMIT_NAME + eip712_salt_: "bytes32" = PERMIT_SALT + eip712_chainId_: "uint256" = PERMIT_CHAIN_ID + eip712_version_: "string" = PERMIT_VERSION + eip712_verifyingContract_: "address" = PERMIT_VAULT_ADDRESS class MessageWithCanonicalDomainFieldOrder(EIP712Message): - eip712_name_: string = PERMIT_NAME - eip712_version_: string = PERMIT_VERSION - eip712_chainId_: uint256 = PERMIT_CHAIN_ID - eip712_verifyingContract_: address = PERMIT_VAULT_ADDRESS - eip712_salt_: bytes32 = PERMIT_SALT + eip712_name_: "string" = PERMIT_NAME + eip712_version_: "string" = PERMIT_VERSION + eip712_chainId_: "uint256" = PERMIT_CHAIN_ID + eip712_verifyingContract_: "address" = PERMIT_VAULT_ADDRESS + eip712_salt_: "bytes32" = PERMIT_SALT class InvalidMessageMissingDomainFields(EIP712Message): - value: uint256 + value: "uint256" @pytest.fixture