Skip to content

Commit 4582584

Browse files
Add NonZeroType (#1527)
1 parent 4922984 commit 4582584

File tree

7 files changed

+141
-3
lines changed

7 files changed

+141
-3
lines changed

docs/migration_guide.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
Migration guide
22
===============
33

4+
******************************
5+
[unreleased] Migration guide
6+
******************************
7+
8+
[unreleased] Minor changes
9+
--------------------------
10+
11+
.. currentmodule:: starknet_py.cairo.data_types
12+
13+
1. Added :class:`NonZeroType` in order to fix parsing ABI which contains Cairo`s `core::zeroable::NonZero <https://github.com/starkware-libs/cairo/blob/a2b9dddeb3212c8d529538454745b27d7a34a6cd/corelib/src/zeroable.cairo#L78>`_
14+
415
******************************
516
0.24.3 Migration guide
617
******************************

starknet_py/abi/v2/parser_transformer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from math import log2
2-
from typing import Any, List, Optional
2+
from typing import Any, List, Optional, Union
33

44
import lark
55
from lark import Token, Transformer
@@ -9,6 +9,7 @@
99
BoolType,
1010
CairoType,
1111
FeltType,
12+
NonZeroType,
1213
OptionType,
1314
TupleType,
1415
TypeIdentifier,
@@ -31,6 +32,7 @@
3132
| type_class_hash
3233
| type_storage_address
3334
| type_option
35+
| type_non_zero
3436
| type_array
3537
| type_span
3638
| tuple
@@ -49,6 +51,7 @@
4951
type_option: "core::option::Option::<" (type | type_identifier) ">"
5052
type_array: "core::array::Array::<" (type | type_identifier) ">"
5153
type_span: "core::array::Span::<" (type | type_identifier) ">"
54+
type_non_zero: "core::zeroable::NonZero::<" (type | type_identifier) ">"
5255
5356
tuple: "(" type? ("," type?)* ")"
5457
@@ -185,6 +188,12 @@ def tuple(self, types: List[CairoType]) -> TupleType:
185188
"""
186189
return TupleType(types)
187190

191+
def type_non_zero(self, value: List[Union[FeltType, UintType]]) -> NonZeroType:
192+
"""
193+
NonZero contains value which is never zero.
194+
"""
195+
return NonZeroType(value[0])
196+
188197

189198
def parse(
190199
code: str,

starknet_py/cairo/data_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,12 @@ class EventType(CairoType):
122122
name: str
123123
types: OrderedDict[str, CairoType]
124124
keys: List[str]
125+
126+
127+
@dataclass
128+
class NonZeroType(CairoType):
129+
"""
130+
Type representation of Cairo NonZero.
131+
"""
132+
133+
type: CairoType
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Generator
3+
4+
from starknet_py.serialization._context import (
5+
Context,
6+
DeserializationContext,
7+
SerializationContext,
8+
)
9+
from starknet_py.serialization.data_serializers.cairo_data_serializer import (
10+
CairoDataSerializer,
11+
)
12+
13+
14+
@dataclass
15+
class NonZeroSerializer(CairoDataSerializer[Any, int]):
16+
"""
17+
Serializer for NonZero type.
18+
Can serialize Cairo types which are non-zero.
19+
Deserializes data to int.
20+
"""
21+
22+
serializer: CairoDataSerializer
23+
24+
def deserialize_with_context(self, context: DeserializationContext) -> int:
25+
deserialized_value = self.serializer.deserialize_with_context(context)
26+
self._ensure_valid_nonzero(deserialized_value, context)
27+
return deserialized_value
28+
29+
def serialize_with_context(
30+
self,
31+
context: SerializationContext,
32+
value: Any,
33+
) -> Generator[int, None, None]:
34+
self._ensure_valid_nonzero(value, context)
35+
return self.serializer.serialize_with_context(context, value)
36+
37+
@staticmethod
38+
def _ensure_valid_nonzero(value: int, context: Context):
39+
context.ensure_valid_value(value != 0, "expected value to be non-zero")

starknet_py/serialization/factory.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
EventType,
1515
FeltType,
1616
NamedTupleType,
17+
NonZeroType,
1718
OptionType,
1819
StructType,
1920
TupleType,
@@ -33,6 +34,9 @@
3334
from starknet_py.serialization.data_serializers.named_tuple_serializer import (
3435
NamedTupleSerializer,
3536
)
37+
from starknet_py.serialization.data_serializers.non_zero_serializer import (
38+
NonZeroSerializer,
39+
)
3640
from starknet_py.serialization.data_serializers.option_serializer import (
3741
OptionSerializer,
3842
)
@@ -119,6 +123,9 @@ def serializer_for_type(cairo_type: CairoType) -> CairoDataSerializer:
119123
if isinstance(cairo_type, OptionType):
120124
return OptionSerializer(serializer_for_type(cairo_type.type))
121125

126+
if isinstance(cairo_type, NonZeroType):
127+
return NonZeroSerializer(serializer_for_type(cairo_type.type))
128+
122129
if isinstance(cairo_type, UnitType):
123130
return UnitSerializer()
124131

starknet_py/tests/e2e/mock/contracts_v2/src/abi_types.cairo

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
use core::serde::Serde;
22
use starknet::ContractAddress;
3+
use core::zeroable::NonZero;
4+
use core::zeroable::IsZeroResult;
5+
use core::zeroable::NonZeroIntoImpl;
6+
use core::integer::{u8_try_as_non_zero};
37

48
#[derive(Drop, Serde)]
59
enum ExampleEnum {
@@ -13,6 +17,8 @@ struct ExampleStruct {
1317
field_b: felt252,
1418
field_c: ExampleEnum,
1519
field_d: (),
20+
field_e: NonZero<felt252>,
21+
field_f: NonZero<u8>,
1622
}
1723

1824
#[starknet::interface]
@@ -23,13 +29,26 @@ trait IAbiTest<TContractState> {
2329
) -> ExampleStruct;
2430
}
2531

32+
pub fn felt_to_nonzero(value: felt252) -> NonZero<felt252> {
33+
match felt252_is_zero(value) {
34+
IsZeroResult::Zero(()) => panic(ArrayTrait::new()),
35+
IsZeroResult::NonZero(x) => x,
36+
}
37+
}
38+
39+
pub fn u8_to_nonzero(value: u8) -> NonZero<u8> {
40+
match u8_try_as_non_zero(value) {
41+
Option::Some(x) => x,
42+
Option::None => panic(ArrayTrait::new()),
43+
}
44+
}
2645

2746
#[starknet::contract]
2847
mod AbiTypes {
2948
use core::array::ArrayTrait;
3049
use core::traits::Into;
3150
use starknet::ContractAddress;
32-
use super::{ExampleEnum, ExampleStruct};
51+
use super::{ExampleEnum, ExampleStruct, felt_to_nonzero, u8_to_nonzero};
3352

3453
#[storage]
3554
struct Storage {}
@@ -46,7 +65,12 @@ mod AbiTypes {
4665
ref self: ContractState, recipient: ContractAddress, amount: u256
4766
) -> ExampleStruct {
4867
ExampleStruct {
49-
field_a: 200, field_b: 300, field_c: ExampleEnum::variant_b(400.into()), field_d: ()
68+
field_a: 200,
69+
field_b: 300,
70+
field_c: ExampleEnum::variant_b(400.into()),
71+
field_d: (),
72+
field_e: felt_to_nonzero(100),
73+
field_f: u8_to_nonzero(100),
5074
}
5175
}
5276
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
from starknet_py.constants import FIELD_PRIME
4+
from starknet_py.serialization import FeltSerializer
5+
from starknet_py.serialization.data_serializers.non_zero_serializer import (
6+
NonZeroSerializer,
7+
)
8+
from starknet_py.serialization.data_serializers.uint_serializer import UintSerializer
9+
10+
11+
@pytest.mark.parametrize(
12+
"serializer, value, serialized_value",
13+
[
14+
(NonZeroSerializer(UintSerializer(128)), 123, [123]),
15+
(NonZeroSerializer(UintSerializer(256)), 1, [1, 0]),
16+
(NonZeroSerializer(FeltSerializer()), 10, [10]),
17+
(NonZeroSerializer(FeltSerializer()), FIELD_PRIME - 1, [FIELD_PRIME - 1]),
18+
],
19+
)
20+
def test_valid_values(serializer, value, serialized_value):
21+
deserialized = serializer.deserialize(serialized_value)
22+
assert deserialized == value
23+
24+
serialized = serializer.serialize(value)
25+
assert serialized == serialized_value
26+
27+
28+
@pytest.mark.parametrize(
29+
"serializer, value, serialized_value",
30+
[
31+
(NonZeroSerializer(UintSerializer(128)), 0, [0]),
32+
(NonZeroSerializer(UintSerializer(256)), 0, [0, 0]),
33+
(NonZeroSerializer(FeltSerializer()), 0, [0]),
34+
],
35+
)
36+
def test_invalid_values(serializer, value, serialized_value):
37+
with pytest.raises(ValueError, match="expected value to be non-zero"):
38+
serializer.deserialize(serialized_value)
39+
serializer.serialize(value)

0 commit comments

Comments
 (0)