Skip to content

Commit 2120e4a

Browse files
committed
Container serialization: iterable -> sequence, plus type aliases
1 parent c2c1d77 commit 2120e4a

File tree

4 files changed

+50
-19
lines changed

4 files changed

+50
-19
lines changed

arraycontext/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
ArrayContainer,
3333
ArrayContainerT,
3434
NotAnArrayContainerError,
35+
SerializationKey,
36+
SerializedContainer,
3537
deserialize_container,
3638
get_container_context_opt,
3739
get_container_context_recursively,
@@ -113,6 +115,8 @@
113115
"PytestPyOpenCLArrayContextFactory",
114116
"Scalar",
115117
"ScalarLike",
118+
"SerializationKey",
119+
"SerializedContainer",
116120
"dataclass_array_container",
117121
"deserialize_container",
118122
"flat_size_and_dtype",
@@ -148,7 +152,7 @@
148152
"with_array_context",
149153
"with_container_arithmetic",
150154
"with_container_arithmetic"
151-
)
155+
)
152156

153157

154158
# {{{ deprecation handling

arraycontext/container/__init__.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
1313
Serialization/deserialization
1414
-----------------------------
15+
16+
.. autoclass:: SerializationKey
17+
.. autoclass:: SerializedContainer
1518
.. autofunction:: is_array_container_type
1619
.. autofunction:: serialize_container
1720
.. autofunction:: deserialize_container
@@ -39,6 +42,14 @@
3942
.. class:: ArrayOrContainerT
4043
4144
:canonical: arraycontext.ArrayOrContainerT
45+
46+
.. class:: SerializationKey
47+
48+
:canonical: arraycontext.SerializationKey
49+
50+
.. class:: SerializedContainer
51+
52+
:canonical: arraycontext.SerializedContainer
4253
"""
4354

4455
from __future__ import annotations
@@ -69,12 +80,23 @@
6980
"""
7081

7182
from functools import singledispatch
72-
from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar
83+
from typing import (
84+
TYPE_CHECKING,
85+
Any,
86+
Hashable,
87+
Iterable,
88+
Optional,
89+
Protocol,
90+
Sequence,
91+
Tuple,
92+
TypeVar,
93+
)
7394

7495
# For use in singledispatch type annotations, because sphinx can't figure out
7596
# what 'np' is.
7697
import numpy
7798
import numpy as np
99+
from typing_extensions import TypeAlias
78100

79101
from arraycontext.context import ArrayContext
80102

@@ -142,23 +164,27 @@ class NotAnArrayContainerError(TypeError):
142164
""":class:`TypeError` subclass raised when an array container is expected."""
143165

144166

167+
SerializationKey: TypeAlias = Hashable
168+
SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]]
169+
170+
145171
@singledispatch
146172
def serialize_container(
147-
ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]:
148-
r"""Serialize the array container into an iterable over its components.
173+
ary: ArrayContainer) -> SerializedContainer:
174+
r"""Serialize the array container into a sequence over its components.
149175
150176
The order of the components and their identifiers are entirely under
151177
the control of the container class. However, the order is required to be
152178
deterministic, i.e. two calls to :func:`serialize_container` on
153179
array containers of the same types with the same number of
154-
sub-arrays must result in an iterable with the keys in the same
180+
sub-arrays must result in a sequence with the keys in the same
155181
order.
156182
157183
If *ary* is mutable, the serialization function is not required to ensure
158184
that the serialization result reflects the array state at the time of the
159185
call to :func:`serialize_container`.
160186
161-
:returns: an :class:`Iterable` of 2-tuples where the first
187+
:returns: an :class:`Sequence` of 2-tuples where the first
162188
entry is an identifier for the component and the second entry
163189
is an array-like component of the :class:`ArrayContainer`.
164190
Components can themselves be :class:`ArrayContainer`\ s, allowing
@@ -172,13 +198,13 @@ def serialize_container(
172198
@singledispatch
173199
def deserialize_container(
174200
template: ArrayContainerT,
175-
iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT:
176-
"""Deserialize an iterable into an array container.
201+
serialized: SerializedContainer) -> ArrayContainerT:
202+
"""Deserialize a sequence into an array container following a *template*.
177203
178204
:param template: an instance of an existing object that
179205
can be used to aid in the deserialization. For a similar choice
180206
see :attr:`~numpy.class.__array_finalize__`.
181-
:param iterable: an iterable that mirrors the output of
207+
:param serialized: a sequence that mirrors the output of
182208
:meth:`serialize_container`.
183209
"""
184210
raise NotAnArrayContainerError(
@@ -242,7 +268,7 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:
242268

243269
@serialize_container.register(np.ndarray)
244270
def _serialize_ndarray_container(
245-
ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]:
271+
ary: numpy.ndarray) -> SerializedContainer:
246272
if ary.dtype.char != "O":
247273
raise NotAnArrayContainerError(
248274
f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
@@ -256,20 +282,20 @@ def _serialize_ndarray_container(
256282
for j in range(ary.shape[1])
257283
]
258284
else:
259-
return np.ndenumerate(ary)
285+
return list(np.ndenumerate(ary))
260286

261287

262288
@deserialize_container.register(np.ndarray)
263289
# https://github.com/python/mypy/issues/13040
264290
def _deserialize_ndarray_container( # type: ignore[misc]
265291
template: numpy.ndarray,
266-
iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray:
292+
serialized: SerializedContainer) -> numpy.ndarray:
267293
# disallow subclasses
268294
assert type(template) is np.ndarray
269295
assert template.dtype.char == "O"
270296

271297
result = type(template)(template.shape, dtype=object)
272-
for i, subary in iterable:
298+
for i, subary in serialized:
273299
result[i] = subary
274300

275301
return result

arraycontext/container/traversal.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from arraycontext.container import (
7878
ArrayContainer,
7979
NotAnArrayContainerError,
80+
SerializationKey,
8081
deserialize_container,
8182
get_container_context_recursively_opt,
8283
serialize_container,
@@ -373,12 +374,9 @@ def wrapper(*args: Any) -> Any:
373374

374375
# {{{ keyed array container traversal
375376

376-
KeyType = Any
377-
378-
379377
def keyed_map_array_container(
380378
f: Callable[
381-
[KeyType, ArrayOrContainer],
379+
[SerializationKey, ArrayOrContainer],
382380
ArrayOrContainer],
383381
ary: ArrayOrContainer) -> ArrayOrContainer:
384382
r"""Applies *f* to all components of an :class:`ArrayContainer`.
@@ -403,7 +401,7 @@ def keyed_map_array_container(
403401

404402

405403
def rec_keyed_map_array_container(
406-
f: Callable[[Tuple[KeyType, ...], ArrayT], ArrayT],
404+
f: Callable[[Tuple[SerializationKey, ...], ArrayT], ArrayT],
407405
ary: ArrayOrContainer) -> ArrayOrContainer:
408406
"""
409407
Works similarly to :func:`rec_map_array_container`, except that *f* also
@@ -412,7 +410,7 @@ def rec_keyed_map_array_container(
412410
the current array.
413411
"""
414412

415-
def rec(keys: Tuple[Union[str, int], ...],
413+
def rec(keys: Tuple[SerializationKey, ...],
416414
_ary: ArrayOrContainerT) -> ArrayOrContainerT:
417415
try:
418416
iterable = serialize_container(_ary)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ dependencies = [
3333
"immutabledict>=4.1",
3434
"numpy",
3535
"pytools>=2024.1.3",
36+
37+
# for TypeAlias
38+
"typing-extensions>=4; python_version<'3.10'",
3639
]
3740

3841
[project.optional-dependencies]

0 commit comments

Comments
 (0)