diff --git a/codegen/generate_tests.py b/codegen/generate_tests.py index 4df99786..6a3babde 100644 --- a/codegen/generate_tests.py +++ b/codegen/generate_tests.py @@ -47,7 +47,6 @@ def test_{entity_snake_case}_roundtrip(instance: {entity_type}) -> None: """ test_code_java = """\ -{xfail} @pytest.mark.java @given(instance=from_type({entity_type})) def test_{entity_snake_case}_java(instance: {entity_type}, java_tester: JavaTester) -> None: @@ -89,20 +88,10 @@ def main() -> None: ) if entity_type.__type__ is not EntityType.nested: - xfail = ( - "" - if entity_type.__name__ not in "UpdateRaftVoterResponse" - else ( - "@pytest.mark.xfail(" - 'reason="https://github.com/Aiven-Open/kio/issues/215"' - ")" - ) - ) module_code[module_path].append( test_code_java.format( entity_type=entity_type.__name__, entity_snake_case=to_snake_case(entity_type.__name__), - xfail=xfail, ) ) diff --git a/src/kio/serial/_implicit_defaults.py b/src/kio/serial/_implicit_defaults.py new file mode 100644 index 00000000..c93420d1 --- /dev/null +++ b/src/kio/serial/_implicit_defaults.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import MISSING +from dataclasses import Field +from dataclasses import fields +from datetime import timedelta +from types import MappingProxyType +from typing import Final +from typing import TypeVar +from typing import assert_never +from uuid import UUID + +from kio.serial._introspect import EntityField +from kio.serial._introspect import EntityTupleField +from kio.serial._introspect import PrimitiveField +from kio.serial._introspect import PrimitiveTupleField +from kio.serial._introspect import classify_field +from kio.serial._introspect import is_optional +from kio.serial.readers import tz_aware_from_i64 +from kio.static.constants import uuid_zero +from kio.static.primitive import Records +from kio.static.primitive import TZAware +from kio.static.primitive import f64 +from kio.static.primitive import i8 
+from kio.static.primitive import i16 +from kio.static.primitive import i32 +from kio.static.primitive import i32Timedelta +from kio.static.primitive import i64 +from kio.static.primitive import i64Timedelta +from kio.static.primitive import u8 +from kio.static.primitive import u16 +from kio.static.primitive import u32 +from kio.static.primitive import u64 +from kio.static.protocol import Entity + +T = TypeVar("T") + +primitive_implicit_defaults: Final[Mapping[type, object]] = MappingProxyType( + { + u8: u8(0), + u16: u16(0), + u32: u32(0), + u64: u64(0), + i8: i8(0), + i16: i16(0), + i32: i32(0), + i64: i64(0), + f64: f64(0.0), + i32Timedelta: i32Timedelta.parse(timedelta(0)), + i64Timedelta: i64Timedelta.parse(timedelta(0)), + TZAware: tz_aware_from_i64(i64(0)), + UUID: uuid_zero, + str: "", + bytes: b"", + } +) + + +def get_implicit_default(field_type: type[T]) -> T: + # Records fields have null as implicit default; supporting this requires changing + # code generation to always expect null for a tagged records field. As of writing + # there are no tagged records fields, or other occurrences where we would need such + # implicit default, so this can be safely deferred. + if issubclass(field_type, Records): + raise NotImplementedError("Tagged record fields are not supported") + + # Nearest registered ancestor in the MRO wins; mypy cannot type a T -> T mapping. 
+ for base in field_type.__mro__: + if base in primitive_implicit_defaults: + return primitive_implicit_defaults[base] # type: ignore[return-value] + raise KeyError(field_type) + + +U = TypeVar("U", bound=Entity) + + +def get_tagged_field_default(field: Field[U]) -> U: + if field.default is not MISSING: + return field.default + + if is_optional(field): + raise TypeError("Optional fields should have None as explicit default") + + field_class = classify_field(field) + + if isinstance(field_class, PrimitiveField): + return get_implicit_default(field_class.type_) + elif isinstance(field_class, EntityField): + return field.type( + **{ + nested_field.name: get_tagged_field_default(nested_field) + for nested_field in fields(field.type) + } + ) + elif isinstance(field_class, PrimitiveTupleField | EntityTupleField): + raise TypeError("Tuple fields should have the empty tuple as explicit default") + + assert_never(field_class) diff --git a/src/kio/serial/_parse.py b/src/kio/serial/_parse.py index ee1a9417..4e6c3135 100644 --- a/src/kio/serial/_parse.py +++ b/src/kio/serial/_parse.py @@ -10,6 +10,7 @@ from kio.static.protocol import Entity from . import readers +from ._implicit_defaults import get_tagged_field_default from ._introspect import EntityField from ._introspect import EntityTupleField from ._introspect import PrimitiveField @@ -165,7 +166,11 @@ def entity_reader( is_tagged_field=tag is not None, ) if tag is not None: - tagged_field_readers[tag] = field, field_reader + tagged_field_readers[tag] = ( + field, + field_reader, + get_tagged_field_default(field), + ) else: field_readers[field] = field_reader @@ -185,12 +190,17 @@ def read_entity(buffer: IO[bytes]) -> E: return entity_type(**kwargs) # Read tagged fields. 
+ tagged_field_values = {} num_tagged_fields = readers.read_unsigned_varint(buffer) for _ in range(num_tagged_fields): field_tag = readers.read_unsigned_varint(buffer) readers.read_unsigned_varint(buffer) # field length - field, field_reader = tagged_field_readers[field_tag] - kwargs[field.name] = field_reader(buffer) + field, field_reader, _ = tagged_field_readers[field_tag] + tagged_field_values[field.name] = field_reader(buffer) + + # Resolve tagged field implicit defaults. + for field, _, implicit_default in tagged_field_readers.values(): + kwargs[field.name] = tagged_field_values.get(field.name, implicit_default) return entity_type(**kwargs) diff --git a/src/kio/serial/_serialize.py b/src/kio/serial/_serialize.py index b346db3a..5f0365e4 100644 --- a/src/kio/serial/_serialize.py +++ b/src/kio/serial/_serialize.py @@ -1,5 +1,6 @@ import io +from dataclasses import MISSING from dataclasses import Field from dataclasses import fields from typing import Literal @@ -11,6 +12,7 @@ from kio.static.protocol import Entity from . import writers +from ._implicit_defaults import get_tagged_field_default from ._introspect import EntityField from ._introspect import EntityTupleField from ._introspect import PrimitiveField @@ -183,7 +185,11 @@ def entity_writer(entity_type: type[E], nullable: bool = False) -> Writer[E | No is_tag=tag is not None, ) if tag is not None: - tagged_field_writers[tag] = field, field_writer + tagged_field_writers[tag] = ( + field, + field_writer, + get_tagged_field_default(field), + ) else: field_writers[field] = field_writer @@ -212,11 +218,17 @@ def write_entity(buffer: Writable, entity: E) -> None: num_tagged_fields = 0 with io.BytesIO() as tag_buffer: # Serialize tagged fields. Note that order is important to fulfill spec. 
- for tag, (field, field_writer) in tagged_field_writers.items(): + for tag, ( + field, + field_writer, + implicit_default, + ) in tagged_field_writers.items(): field_value = getattr(entity, field.name) # Skip default-valued fields. - if field_value == field.default: + if field_value == field.default or ( + field.default is MISSING and field_value == implicit_default + ): continue # Write the tag to the buffer and increase counter. diff --git a/src/kio/serial/readers.py b/src/kio/serial/readers.py index efd424e1..bcaeee62 100644 --- a/src/kio/serial/readers.py +++ b/src/kio/serial/readers.py @@ -204,7 +204,7 @@ def read_timedelta_i64(buffer: IO[bytes]) -> i64Timedelta: return datetime.timedelta(milliseconds=read_int64(buffer)) # type: ignore[return-value] -def _tz_aware_from_i64(timestamp: i64) -> TZAware: +def tz_aware_from_i64(timestamp: i64) -> TZAware: dt = datetime.datetime.fromtimestamp(timestamp / 1000, datetime.UTC) try: return TZAware.truncate(dt) @@ -213,11 +213,11 @@ def _tz_aware_from_i64(timestamp: i64) -> TZAware: def read_datetime_i64(buffer: IO[bytes]) -> TZAware: - return _tz_aware_from_i64(read_int64(buffer)) + return tz_aware_from_i64(read_int64(buffer)) def read_nullable_datetime_i64(buffer: IO[bytes]) -> TZAware | None: timestamp = read_int64(buffer) if timestamp == -1: return None - return _tz_aware_from_i64(timestamp) + return tz_aware_from_i64(timestamp) diff --git a/tests/generated/test_update_raft_voter_v0_response.py b/tests/generated/test_update_raft_voter_v0_response.py index 55232f34..dedc2d12 100644 --- a/tests/generated/test_update_raft_voter_v0_response.py +++ b/tests/generated/test_update_raft_voter_v0_response.py @@ -44,7 +44,6 @@ def test_update_raft_voter_response_roundtrip( assert instance == result -@pytest.mark.xfail(reason="https://github.com/Aiven-Open/kio/issues/215") @pytest.mark.java @given(instance=from_type(UpdateRaftVoterResponse)) def test_update_raft_voter_response_java( diff --git 
a/tests/serial/test_implicit_defaults.py b/tests/serial/test_implicit_defaults.py new file mode 100644 index 00000000..5921a5cc --- /dev/null +++ b/tests/serial/test_implicit_defaults.py @@ -0,0 +1,118 @@ +from dataclasses import dataclass +from dataclasses import fields +from datetime import timedelta +from uuid import UUID + +import pytest + +from kio.schema.types import BrokerId +from kio.schema.types import GroupId +from kio.schema.types import ProducerId +from kio.schema.types import TopicName +from kio.schema.types import TransactionalId +from kio.serial._implicit_defaults import get_implicit_default +from kio.serial._implicit_defaults import get_tagged_field_default +from kio.serial.readers import tz_aware_from_i64 +from kio.static.primitive import Records +from kio.static.primitive import TZAware +from kio.static.primitive import f64 +from kio.static.primitive import i8 +from kio.static.primitive import i16 +from kio.static.primitive import i32 +from kio.static.primitive import i32Timedelta +from kio.static.primitive import i64 +from kio.static.primitive import i64Timedelta +from kio.static.primitive import u8 +from kio.static.primitive import u16 +from kio.static.primitive import u32 +from kio.static.primitive import u64 + + +class TestGetImplicitDefault: + @pytest.mark.parametrize( + ("annotation", "expected"), + ( + (u8, 0), + (u16, 0), + (u32, 0), + (u64, 0), + (i8, 0), + (i16, 0), + (i32, 0), + (i64, 0), + (f64, 0.0), + (i32Timedelta, timedelta(0)), + (i64Timedelta, timedelta(0)), + (TZAware, tz_aware_from_i64(i64(0))), + (UUID, UUID(int=0)), + (str, ""), + (bytes, b""), + (BrokerId, 0), + (GroupId, ""), + (ProducerId, 0), + (TopicName, ""), + (TransactionalId, ""), + ), + ) + def test_returns_expected_value(self, annotation: type, expected: object) -> None: + assert get_implicit_default(annotation) == expected + + def test_raises_not_implemented_error_for_records(self) -> None: + with pytest.raises(NotImplementedError): + get_implicit_default(Records) 
+ + +class TestGetTaggedFieldDefault: + def test_raises_type_error_for_optional_field(self) -> None: + @dataclass + class A: + a: u8 | None + + [field] = fields(A) + + with pytest.raises( + TypeError, + match=r"Optional fields should have None as explicit default", + ): + get_tagged_field_default(field) + + def test_raises_type_error_for_tuple_field(self) -> None: + @dataclass + class A: + a: tuple[u8, ...] + + [field] = fields(A) + + with pytest.raises( + TypeError, + match=r"Tuple fields should have the empty tuple as explicit default", + ): + get_tagged_field_default(field) + + def test_can_get_default_for_primitive_field(self) -> None: + @dataclass + class A: + a: u8 + + [field] = fields(A) + assert get_tagged_field_default(field) == 0 + + def test_can_get_default_for_entity_field(self) -> None: + @dataclass + class A: + a: u8 + + @dataclass + class B: + b: A + + [field] = fields(B) + assert get_tagged_field_default(field) == A(a=u8(0)) + + def test_returns_explicit_default_if_defined(self) -> None: + @dataclass + class A: + a: u8 = u8(1) + + [field] = fields(A) + assert get_tagged_field_default(field) == u8(1) diff --git a/tests/serial/test_tagged_fields.py b/tests/serial/test_tagged_fields.py index ca1f1d0c..5f603134 100644 --- a/tests/serial/test_tagged_fields.py +++ b/tests/serial/test_tagged_fields.py @@ -25,6 +25,16 @@ from kio.static.primitive import u8 +@dataclass(frozen=True, slots=True, kw_only=True) +class Nested: + __type__: ClassVar = EntityType.data + __version__: ClassVar[i16] = i16(0) + __flexible__: ClassVar[bool] = True + __api_key__: ClassVar[i16] = i16(-1) + str_field: str = field(metadata={"kafka_type": "string"}) + int_field: u8 = field(metadata={"kafka_type": "uint8"}) + + @dataclass(frozen=True, slots=True, kw_only=True) class Person: __type__: ClassVar = EntityType.data @@ -38,6 +48,7 @@ class Person: metadata={"kafka_type": "string", "tag": 1}, default=None, ) + tagged_struct: Nested = field(metadata={"tag": 2}) read_person = 
entity_reader(Person) @@ -57,14 +68,42 @@ class WritableTag(Generic[T]): [ ( [WritableTag(tag=0, writer=write_uint8, value=u8(123))], - Person(age=u8(123)), + Person( + age=u8(123), + tagged_struct=Nested(str_field="", int_field=u8(0)), + ), ), ( [ WritableTag(tag=0, writer=write_uint8, value=u8(12)), WritableTag(tag=1, writer=write_compact_string, value="Borduria"), ], - Person(age=u8(12), country="Borduria"), + Person( + age=u8(12), + country="Borduria", + tagged_struct=Nested(str_field="", int_field=u8(0)), + ), + ), + ( + [WritableTag(tag=1, writer=write_compact_string, value="Borduria")], + Person( + age=u8(0), + country="Borduria", + tagged_struct=Nested(str_field="", int_field=u8(0)), + ), + ), + ( + [ + WritableTag( + tag=2, + writer=entity_writer(Nested), + value=Nested(str_field="hello", int_field=u8(123)), + ), + ], + Person( + age=u8(0), + tagged_struct=Nested(str_field="hello", int_field=u8(123)), + ), ), ], ) @@ -89,24 +128,6 @@ def test_can_parse_tagged_fields( assert read_person(buffer) == expected -def test_raises_type_error_when_missing_required_tagged_field( - buffer: io.BytesIO, -) -> None: - write_compact_string(buffer, "Almaszout") # name - - write_unsigned_varint(buffer, 1) # num tagged fields - # Only write country, omit age. 
- write_tagged_field(buffer, 1, write_compact_string, "Borduria") - - buffer.seek(0) - - with pytest.raises( - TypeError, - match=r"missing 1 required keyword-only argument: 'age'", - ): - read_person(buffer) - - @dataclass(frozen=True, slots=True, kw_only=True) class ReadableTag(Generic[T]): tag: int @@ -118,20 +139,55 @@ class ReadableTag(Generic[T]): ("instance", "expected_tags"), [ ( - Person(age=u8(123)), + Person( + age=u8(123), + tagged_struct=Nested(str_field="", int_field=u8(0)), + ), [ReadableTag(tag=0, reader=read_uint8, value=u8(123))], ), ( - Person(age=u8(12), country="Borduria"), + Person( + age=u8(12), + country="Borduria", + tagged_struct=Nested(str_field="", int_field=u8(0)), + ), [ ReadableTag(tag=0, reader=read_uint8, value=u8(12)), ReadableTag(tag=1, reader=read_compact_string, value="Borduria"), ], ), ( - Person(age=u8(1), country=None), + Person( + age=u8(1), + country=None, + tagged_struct=Nested(str_field="", int_field=u8(0)), + ), [ReadableTag(tag=0, reader=read_uint8, value=u8(1))], ), + ( + Person(age=u8(0), tagged_struct=Nested(str_field="", int_field=u8(0))), + [], + ), + ( + Person(age=u8(0), tagged_struct=Nested(str_field="hello", int_field=u8(0))), + [ + ReadableTag( + tag=2, + reader=entity_reader(Nested), + value=Nested(str_field="hello", int_field=u8(0)), + ) + ], + ), + ( + Person(age=u8(0), tagged_struct=Nested(str_field="", int_field=u8(1))), + [ + ReadableTag( + tag=2, + reader=entity_reader(Nested), + value=Nested(str_field="", int_field=u8(1)), + ) + ], + ), ], ) def test_can_serialize_tagged_fields(