diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7742ade..19127cd 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -63,7 +63,7 @@ repos:
- --multi-line=9
- --project=pydantic_xml
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.991
+ rev: v1.0.0
hooks:
- id: mypy
stages:
diff --git a/examples/snippets/model_template.py b/examples/snippets/model_template.py
index b9a4ee8..8b6a0cc 100644
--- a/examples/snippets/model_template.py
+++ b/examples/snippets/model_template.py
@@ -27,7 +27,7 @@ class Config:
class Vehicles(BaseXmlModel, tag='vehicles'):
- items: List[Union[Car, Airplane]]
+ items: List[Union[Car, Airplane]] = element()
# [model-end]
diff --git a/examples/snippets/union_models.py b/examples/snippets/union_models.py
index 791f2da..573a0ac 100644
--- a/examples/snippets/union_models.py
+++ b/examples/snippets/union_models.py
@@ -18,7 +18,7 @@ class MouseEvent(Event, tag='mouse'):
class Log(BaseXmlModel, tag='log'):
- events: List[Union[KeyboardEvent, MouseEvent]]
+ events: List[Union[KeyboardEvent, MouseEvent]] = element()
# [model-end]
diff --git a/examples/snippets/union_primitives.py b/examples/snippets/union_primitives.py
index 434efdc..f4356fa 100644
--- a/examples/snippets/union_primitives.py
+++ b/examples/snippets/union_primitives.py
@@ -1,7 +1,7 @@
import datetime as dt
from typing import List, Optional, Union
-from pydantic_xml import BaseXmlModel, attr
+from pydantic_xml import BaseXmlModel, attr, element
# [model-start]
@@ -11,7 +11,7 @@ class Message(BaseXmlModel, tag='Message'):
class Messages(BaseXmlModel):
- messages: List[Message]
+ messages: List[Message] = element()
# [model-end]
diff --git a/pydantic_xml/serializers/factories/homogeneous.py b/pydantic_xml/serializers/factories/homogeneous.py
index 09b9eac..77177dd 100644
--- a/pydantic_xml/serializers/factories/homogeneous.py
+++ b/pydantic_xml/serializers/factories/homogeneous.py
@@ -1,6 +1,6 @@
import dataclasses as dc
from copy import deepcopy
-from typing import Any, List, Optional, Type
+from typing import Any, Collection, List, Optional, Type
import pydantic as pd
@@ -8,7 +8,7 @@
from pydantic_xml import errors
from pydantic_xml.element import XmlElementReader, XmlElementWriter
from pydantic_xml.serializers.encoder import XmlEncoder
-from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer
+from pydantic_xml.serializers.serializer import Location, PydanticShapeType, Serializer, is_xml_model
from pydantic_xml.utils import QName, merge_nsmaps
@@ -17,6 +17,79 @@ class HomogeneousSerializerFactory:
Homogeneous collection type serializer factory.
"""
+ class TextSerializer(Serializer):
+ def __init__(
+ self, model: Type['pxml.BaseXmlModel'], model_field: pd.fields.ModelField, ctx: Serializer.Context,
+ ):
+ assert model_field.sub_fields and len(model_field.sub_fields) == 1
+ if (
+ is_xml_model(model_field.type_) or
+ issubclass(model_field.type_, tuple)
+ ):
+ raise errors.ModelFieldError(
+ model.__name__, model_field.name, "Inline list value should be of scalar type",
+ )
+
+ def serialize(
+ self, element: XmlElementWriter, value: Collection[Any], *, encoder: XmlEncoder,
+ skip_empty: bool = False,
+ ) -> Optional[XmlElementWriter]:
+ if value is None or skip_empty and len(value) == 0:
+ return element
+
+ encoded = " ".join(encoder.encode(val) for val in value)
+ element.set_text(encoded)
+ return element
+
+ def deserialize(self, element: Optional[XmlElementReader]) -> Optional[List[Any]]:
+ if element is None:
+ return None
+
+ text = element.pop_text()
+
+ if text is None:
+ return None
+
+ return [value for value in text.split()]
+
+ class AttributeSerializer(Serializer):
+ def __init__(
+ self, model: Type['pxml.BaseXmlModel'], model_field: pd.fields.ModelField, ctx: Serializer.Context,
+ ):
+ assert model_field.sub_fields and len(model_field.sub_fields) == 1
+ if issubclass(model_field.type_, pxml.BaseXmlModel):
+ raise errors.ModelFieldError(
+ model.__name__, model_field.name, "Inline list value should be of scalar type",
+ )
+
+ _, ns, nsmap = self._get_entity_info(model_field)
+
+ name = model_field.alias
+
+ self.attr_name = QName.from_alias(tag=name, ns=ns, nsmap=nsmap, is_attr=True).uri
+
+ def serialize(
+ self, element: XmlElementWriter, value: Collection[Any], *, encoder: XmlEncoder,
+ skip_empty: bool = False,
+ ) -> Optional[XmlElementWriter]:
+ if value is None or skip_empty and len(value) == 0:
+ return element
+
+ encoded = " ".join(encoder.encode(val) for val in value)
+ element.set_attribute(self.attr_name, encoded)
+ return element
+
+ def deserialize(self, element: Optional[XmlElementReader]) -> Optional[List[Any]]:
+ if element is None:
+ return None
+
+ attribute = element.pop_attrib(self.attr_name)
+
+ if attribute is None:
+ return []
+
+ return [value for value in attribute.split()]
+
class ElementSerializer(Serializer):
def __init__(
self, model: Type['pxml.BaseXmlModel'], model_field: pd.fields.ModelField, ctx: Serializer.Context,
@@ -103,10 +176,8 @@ def build(
if field_location is Location.ELEMENT:
return cls.ElementSerializer(model, model_field, ctx)
elif field_location is Location.MISSING:
- return cls.ElementSerializer(model, model_field, ctx)
+ return cls.TextSerializer(model, model_field, ctx)
elif field_location is Location.ATTRIBUTE:
- raise errors.ModelFieldError(
- model.__name__, model_field.name, "attributes of collection type are not supported",
- )
+ return cls.AttributeSerializer(model, model_field, ctx)
else:
raise AssertionError("unreachable")
diff --git a/tests/test_homogeneous_collections.py b/tests/test_homogeneous_collections.py
index cb0b61e..4e94e9b 100644
--- a/tests/test_homogeneous_collections.py
+++ b/tests/test_homogeneous_collections.py
@@ -121,11 +121,93 @@ class RootModel(BaseXmlModel, tag='model'):
assert_xml_equal(actual_xml, xml)
-def test_homogeneous_definition_errors():
- with pytest.raises(errors.ModelFieldError):
- class TestModel(BaseXmlModel):
- attr1: List[int] = attr()
+def test_text_list_extraction():
+ class RootModel(BaseXmlModel, tag="model"):
+ values: List[int]
+
+ xml = '''
+ 1 2 70 -34
+ '''
+
+ actual_obj = RootModel.from_xml(xml)
+ expected_obj = RootModel(
+ values = [1, 2, 70, -34],
+ )
+
+ assert actual_obj == expected_obj
+
+ actual_xml = actual_obj.to_xml()
+ assert_xml_equal(actual_xml, xml)
+
+
+def test_text_tuple_extraction():
+ class RootModel(BaseXmlModel, tag="model"):
+ values: Tuple[int, ...]
+ xml = '''
+ 1 2 70 -34
+ '''
+
+ actual_obj = RootModel.from_xml(xml)
+ expected_obj = RootModel(
+ values=[1, 2, 70, -34],
+ )
+
+ assert actual_obj == expected_obj
+
+ actual_xml = actual_obj.to_xml()
+ assert_xml_equal(actual_xml, xml)
+
+
+def test_attr_list_extraction():
+ class RootModel(BaseXmlModel, tag="model"):
+ values: List[float] = attr()
+
+ xml = '''
+
+ '''
+ # This will fail if scientific notation is used
+ # i.e. if 300 is replaced with 3e2 or 300, the deserializer
+ # will always use the standard notation with the added `.0`.
+ # While this behaviour fails the tests, it shouldn't
+ # matter in practice.
+
+ actual_obj = RootModel.from_xml(xml)
+ expected_obj = RootModel(
+ values=[3.14, -1.0, 3e2],
+ )
+
+ assert actual_obj == expected_obj
+
+ actual_xml = actual_obj.to_xml()
+ assert_xml_equal(actual_xml, xml)
+
+
+def test_attr_tuple_extraction():
+ class RootModel(BaseXmlModel, tag="model"):
+ values: Tuple[float, ...] = attr()
+
+ xml = '''
+
+ '''
+ # This will fail if scientific notation is used
+ # i.e. if 300 is replaced with 3e2 or 300, the deserializer
+ # will always use the standard notation with the added `.0`.
+ # While this behaviour fails the tests, it shouldn't
+ # matter in practice.
+
+ actual_obj = RootModel.from_xml(xml)
+ expected_obj = RootModel(
+ values=(3.14, -1.0, 3e2),
+ )
+
+ assert actual_obj == expected_obj
+
+ actual_xml = actual_obj.to_xml()
+ assert_xml_equal(actual_xml, xml)
+
+
+def test_homogeneous_definition_errors():
with pytest.raises(errors.ModelFieldError):
class TestModel(BaseXmlModel):
attr1: List[Tuple[int, ...]]
@@ -156,3 +238,10 @@ class TestSubModel(BaseXmlModel):
class TestModel(BaseXmlModel):
__root__: List[TestSubModel]
+
+ with pytest.raises(errors.ModelFieldError):
+ class TestSubModel(BaseXmlModel):
+ attr: int
+
+ class TestModel(BaseXmlModel):
+ text: List[TestSubModel]
diff --git a/tests/test_misc.py b/tests/test_misc.py
index cbab977..30e7003 100644
--- a/tests/test_misc.py
+++ b/tests/test_misc.py
@@ -46,7 +46,7 @@ class TestSubModel(BaseXmlModel, tag='model'):
class TestModel(BaseXmlModel, tag='model'):
model: TestSubModel
- list: List[TestSubModel] = []
+ list: List[TestSubModel] = element(default=[])
tuple: Optional[Tuple[TestSubModel, TestSubModel]] = None
attrs: Dict[str, str] = {}
wrapped: Optional[str] = wrapped('envelope')