diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7742ade..19127cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,7 +63,7 @@ repos: - --multi-line=9 - --project=pydantic_xml - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.991 + rev: v1.0.0 hooks: - id: mypy stages: diff --git a/examples/snippets/model_template.py b/examples/snippets/model_template.py index b9a4ee8..8b6a0cc 100644 --- a/examples/snippets/model_template.py +++ b/examples/snippets/model_template.py @@ -27,7 +27,7 @@ class Config: class Vehicles(BaseXmlModel, tag='vehicles'): - items: List[Union[Car, Airplane]] + items: List[Union[Car, Airplane]] = element() # [model-end] diff --git a/examples/snippets/union_models.py b/examples/snippets/union_models.py index 791f2da..573a0ac 100644 --- a/examples/snippets/union_models.py +++ b/examples/snippets/union_models.py @@ -18,7 +18,7 @@ class MouseEvent(Event, tag='mouse'): class Log(BaseXmlModel, tag='log'): - events: List[Union[KeyboardEvent, MouseEvent]] + events: List[Union[KeyboardEvent, MouseEvent]] = element() # [model-end] diff --git a/examples/snippets/union_primitives.py b/examples/snippets/union_primitives.py index 434efdc..f4356fa 100644 --- a/examples/snippets/union_primitives.py +++ b/examples/snippets/union_primitives.py @@ -1,7 +1,7 @@ import datetime as dt from typing import List, Optional, Union -from pydantic_xml import BaseXmlModel, attr +from pydantic_xml import BaseXmlModel, attr, element # [model-start] @@ -11,7 +11,7 @@ class Message(BaseXmlModel, tag='Message'): class Messages(BaseXmlModel): - messages: List[Message] + messages: List[Message] = element() # [model-end] diff --git a/pydantic_xml/serializers/factories/homogeneous.py b/pydantic_xml/serializers/factories/homogeneous.py index 09b9eac..77177dd 100644 --- a/pydantic_xml/serializers/factories/homogeneous.py +++ b/pydantic_xml/serializers/factories/homogeneous.py @@ -1,6 +1,6 @@ import dataclasses as dc from copy import deepcopy -from typing import Any, List, Optional, Type +from typing import Any, Collection, List, Optional, Type import pydantic as pd @@ -8,7 +8,7 @@ from pydantic_xml import errors from pydantic_xml.element import XmlElementReader, XmlElementWriter from pydantic_xml.serializers.encoder import XmlEncoder -from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer +from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer, is_xml_model from pydantic_xml.utils import QName, merge_nsmaps @@ -17,6 +17,79 @@ class HomogeneousSerializerFactory: Homogeneous collection type serializer factory. """ + class TextSerializer(Serializer): + def __init__( + self, model: Type['pxml.BaseXmlModel'], model_field: pd.fields.ModelField, ctx: Serializer.Context, + ): + assert model_field.sub_fields and len(model_field.sub_fields) == 1 + if ( + is_xml_model(model_field.type_) or + issubclass(model_field.type_, tuple) + ): + raise errors.ModelFieldError( + model.__name__, model_field.name, "Inline list value should be of scalar type", + ) + + def serialize( + self, element: XmlElementWriter, value: Collection[Any], *, encoder: XmlEncoder, + skip_empty: bool = False, + ) -> Optional[XmlElementWriter]: + if value is None or skip_empty and len(value) == 0: + return element + + encoded = " ".join(encoder.encode(val) for val in value) + element.set_text(encoded) + return element + + def deserialize(self, element: Optional[XmlElementReader]) -> Optional[List[Any]]: + if element is None: + return None + + text = element.pop_text() + + if text is None: + return None + + return [value for value in text.split()] + + class AttributeSerializer(Serializer): + def __init__( + self, model: Type['pxml.BaseXmlModel'], model_field: pd.fields.ModelField, ctx: Serializer.Context, + ): + assert model_field.sub_fields and len(model_field.sub_fields) == 1 + if issubclass(model_field.type_, pxml.BaseXmlModel): + raise errors.ModelFieldError( + model.__name__, model_field.name, "Inline list value should be of scalar type", + ) + + _, ns, nsmap = self._get_entity_info(model_field) + + name = model_field.alias + + self.attr_name = QName.from_alias(tag=name, ns=ns, nsmap=nsmap, is_attr=True).uri + + def serialize( + self, element: XmlElementWriter, value: Collection[Any], *, encoder: XmlEncoder, + skip_empty: bool = False, + ) -> Optional[XmlElementWriter]: + if value is None or skip_empty and len(value) == 0: + return element + + encoded = " ".join(encoder.encode(val) for val in value) + element.set_attribute(self.attr_name, encoded) + return element + + def deserialize(self, element: Optional[XmlElementReader]) -> Optional[List[Any]]: + if element is None: + return None + + attribute = element.pop_attrib(self.attr_name) + + if attribute is None: + return [] + + return [value for value in attribute.split()] + class ElementSerializer(Serializer): def __init__( self, model: Type['pxml.BaseXmlModel'], model_field: pd.fields.ModelField, ctx: Serializer.Context, @@ -103,10 +176,8 @@ def build( if field_location is Location.ELEMENT: return cls.ElementSerializer(model, model_field, ctx) elif field_location is Location.MISSING: - return cls.ElementSerializer(model, model_field, ctx) + return cls.TextSerializer(model, model_field, ctx) elif field_location is Location.ATTRIBUTE: - raise errors.ModelFieldError( - model.__name__, model_field.name, "attributes of collection type are not supported", - ) + return cls.AttributeSerializer(model, model_field, ctx) else: raise AssertionError("unreachable") diff --git a/tests/test_homogeneous_collections.py b/tests/test_homogeneous_collections.py index cb0b61e..4e94e9b 100644 --- a/tests/test_homogeneous_collections.py +++ b/tests/test_homogeneous_collections.py @@ -121,11 +121,93 @@ class RootModel(BaseXmlModel, tag='model'): assert_xml_equal(actual_xml, xml) -def test_homogeneous_definition_errors(): - with pytest.raises(errors.ModelFieldError): - class TestModel(BaseXmlModel): - attr1: List[int] = attr() +def test_text_list_extraction(): + class RootModel(BaseXmlModel, tag="model"): + values: List[int] + + xml = ''' + 1 2 70 -34 + ''' + + actual_obj = RootModel.from_xml(xml) + expected_obj = RootModel( + values = [1, 2, 70, -34], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_text_tuple_extraction(): + class RootModel(BaseXmlModel, tag="model"): + values: Tuple[int, ...] + xml = ''' + 1 2 70 -34 + ''' + + actual_obj = RootModel.from_xml(xml) + expected_obj = RootModel( + values=[1, 2, 70, -34], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_attr_list_extraction(): + class RootModel(BaseXmlModel, tag="model"): + values: List[float] = attr() + + xml = ''' + + ''' + # This will fail if scientific notation is used + # i.e. if 300 is replaced with 3e2 or 300, the deserializer + # will always use the standard notation with the added `.0`. + # While this behaviour fails the tests, it shouldn't + # matter in practice. + + actual_obj = RootModel.from_xml(xml) + expected_obj = RootModel( + values=[3.14, -1.0, 3e2], + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_attr_tuple_extraction(): + class RootModel(BaseXmlModel, tag="model"): + values: Tuple[float, ...] = attr() + + xml = ''' + + ''' + # This will fail if scientific notation is used + # i.e. if 300 is replaced with 3e2 or 300, the deserializer + # will always use the standard notation with the added `.0`. + # While this behaviour fails the tests, it shouldn't + # matter in practice. + + actual_obj = RootModel.from_xml(xml) + expected_obj = RootModel( + values=(3.14, -1.0, 3e2), + ) + + assert actual_obj == expected_obj + + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_homogeneous_definition_errors(): with pytest.raises(errors.ModelFieldError): class TestModel(BaseXmlModel): attr1: List[Tuple[int, ...]] @@ -156,3 +238,10 @@ class TestSubModel(BaseXmlModel): class TestModel(BaseXmlModel): __root__: List[TestSubModel] + + with pytest.raises(errors.ModelFieldError): + class TestSubModel(BaseXmlModel): + attr: int + + class TestModel(BaseXmlModel): + text: List[TestSubModel] diff --git a/tests/test_misc.py b/tests/test_misc.py index cbab977..30e7003 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -46,7 +46,7 @@ class TestSubModel(BaseXmlModel, tag='model'): class TestModel(BaseXmlModel, tag='model'): model: TestSubModel - list: List[TestSubModel] = [] + list: List[TestSubModel] = element(default=[]) tuple: Optional[Tuple[TestSubModel, TestSubModel]] = None attrs: Dict[str, str] = {} wrapped: Optional[str] = wrapped('envelope')