3030THE SOFTWARE.
3131"""
3232
33- from typing import Union , get_args
33+ from typing import Tuple , Union , get_args
3434try :
3535 # NOTE: only available in python >= 3.8
3636 from typing import get_origin
3737except ImportError :
3838 from typing_extensions import get_origin
3939
40- from dataclasses import fields
40+ from dataclasses import Field , is_dataclass , fields
4141from arraycontext .container import is_array_container_type
4242
4343
4444# {{{ dataclass containers
4545
46+ def is_array_type (tp : type ) -> bool :
47+ from arraycontext import Array
48+ return tp is Array or is_array_container_type (tp )
49+
50+
4651def dataclass_array_container (cls : type ) -> type :
4752 """A class decorator that makes the class to which it is applied an
4853 :class:`ArrayContainer` by registering appropriate implementations of
@@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type:
5156
5257 Attributes that are not array containers are allowed. In order to decide
5358 whether an attribute is an array container, the declared attribute type
54- is checked by the criteria from :func:`is_array_container_type`.
59+ is checked by the criteria from :func:`is_array_container_type`. This
60+ includes some support for type annotations:
61+
62+ * a :class:`typing.Union` of array containers is considered an array container.
63+ * other type annotations, e.g. :class:`typing.Optional`, are not considered
64+ array containers, even if they wrap one.
5565 """
56- from dataclasses import is_dataclass , Field
66+
5767 assert is_dataclass (cls )
5868
5969 def is_array_field (f : Field ) -> bool :
60- from arraycontext import Array
70+ # NOTE: unions of array containers are treated separately to handle
71+ # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
72+ # they can work seamlessly with arithmetic and traversal.
73+ #
74+ # `Optional[ArrayContainer]` is not allowed, since `None` is not
75+ # handled by `with_container_arithmetic`, which is the common case
76+ # for current container usage. Other type annotations, e.g.
77+ # `Tuple[Container, Container]`, are also not allowed, as they do not
78+ # work with `with_container_arithmetic`.
79+ #
80+ # This is not set in stone, but mostly driven by current usage!
6181
6282 origin = get_origin (f .type )
6383 if origin is Union :
64- if not all (
65- arg is Array or is_array_container_type ( arg )
66- for arg in get_args ( f . type )) :
84+ if all (is_array_type ( arg ) for arg in get_args ( f . type )):
85+ return True
86+ else :
6787 raise TypeError (
6888 f"Field '{ f .name } ' union contains non-array container "
6989 "arguments. All arguments must be array containers." )
70- else :
71- return True
7290
7391 if __debug__ :
7492 if not f .init :
@@ -79,8 +97,12 @@ def is_array_field(f: Field) -> bool:
7997 raise TypeError (
8098 f"string annotation on field '{ f .name } ' not supported" )
8199
82- from typing import _SpecialForm
83- if isinstance (f .type , _SpecialForm ):
100+ # NOTE:
101+ # * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
102+ # * `_SpecialForm` catches `Any`, `Literal`, etc.
103+ from typing import ( # type: ignore[attr-defined]
104+ _BaseGenericAlias , _SpecialForm )
105+ if isinstance (f .type , (_BaseGenericAlias , _SpecialForm )):
84106 # NOTE: anything except a Union is not allowed
85107 raise TypeError (
86108 f"typing annotation not supported on field '{ f .name } ': "
@@ -91,7 +113,7 @@ def is_array_field(f: Field) -> bool:
91113 f"field '{ f .name } ' not an instance of 'type': "
92114 f"'{ f .type !r} '" )
93115
94- return f . type is Array or is_array_container_type (f .type )
116+ return is_array_type (f .type )
95117
96118 from pytools import partition
97119 array_fields , non_array_fields = partition (is_array_field , fields (cls ))
@@ -100,6 +122,27 @@ def is_array_field(f: Field) -> bool:
100122 raise ValueError (f"'{ cls } ' must have fields with array container type "
101123 "in order to use the 'dataclass_array_container' decorator" )
102124
125+ return inject_dataclass_serialization (cls , array_fields , non_array_fields )
126+
127+
128+ def inject_dataclass_serialization (
129+ cls : type ,
130+ array_fields : Tuple [Field , ...],
131+ non_array_fields : Tuple [Field , ...]) -> type :
132+ """Implements :func:`~arraycontext.serialize_container` and
133+ :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
134+
135+ This function modifies *cls* in place, so the returned value is the same
136+ object with additional functionality.
137+
138+ :arg array_fields: fields of the given dataclass *cls* which are considered
139+ array containers and should be serialized.
140+ :arg non_array_fields: remaining fields of the dataclass *cls* which are
141+ copied over from the template array in deserialization.
142+ """
143+
144+ assert is_dataclass (cls )
145+
103146 serialize_expr = ", " .join (
104147 f"({ f .name !r} , ary.{ f .name } )" for f in array_fields )
105148 template_kwargs = ", " .join (
0 commit comments