Skip to content

Commit 17dbce4

Browse files
authored
Merge branch 'main' into dataclass-container-strings
2 parents 1d3d906 + 1db13f2 commit 17dbce4

File tree

5 files changed

+36
-13
lines changed

5 files changed

+36
-13
lines changed

arraycontext/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"""
3030

3131
import sys
32-
from .context import ArrayContext
32+
from .context import ArrayContext, DeviceArray, DeviceScalar
3333

3434
from .transform_metadata import (CommonSubexpressionTag,
3535
ElementwiseMapKernelTag)
@@ -74,7 +74,7 @@
7474

7575

7676
__all__ = (
77-
"ArrayContext",
77+
"ArrayContext", "DeviceScalar", "DeviceArray",
7878

7979
"CommonSubexpressionTag",
8080
"ElementwiseMapKernelTag",

arraycontext/container/traversal.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767

6868
import numpy as np
6969

70-
from arraycontext.context import ArrayContext
70+
from arraycontext.context import ArrayContext, DeviceArray
7171
from arraycontext.container import (
7272
ContainerT, ArrayOrContainerT, NotAnArrayContainerError,
7373
serialize_container, deserialize_container)
@@ -355,7 +355,7 @@ def rec(keys: Tuple[Union[str, int], ...],
355355
def map_reduce_array_container(
356356
reduce_func: Callable[[Iterable[Any]], Any],
357357
map_func: Callable[[Any], Any],
358-
ary: ArrayOrContainerT) -> Any:
358+
ary: ArrayOrContainerT) -> "DeviceArray":
359359
"""Perform a map-reduce over array containers.
360360
361361
:param reduce_func: callable used to reduce over the components of *ary*
@@ -378,7 +378,7 @@ def map_reduce_array_container(
378378
def multimap_reduce_array_container(
379379
reduce_func: Callable[[Iterable[Any]], Any],
380380
map_func: Callable[..., Any],
381-
*args: Any) -> Any:
381+
*args: Any) -> "DeviceArray":
382382
r"""Perform a map-reduce over multiple array containers.
383383
384384
:param reduce_func: callable used to reduce over the components of any
@@ -401,7 +401,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401401
def rec_map_reduce_array_container(
402402
reduce_func: Callable[[Iterable[Any]], Any],
403403
map_func: Callable[[Any], Any],
404-
ary: ArrayOrContainerT) -> Any:
404+
ary: ArrayOrContainerT) -> "DeviceArray":
405405
"""Perform a map-reduce over array containers recursively.
406406
407407
:param reduce_func: callable used to reduce over the components of *ary*
@@ -455,7 +455,7 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
455455
def rec_multimap_reduce_array_container(
456456
reduce_func: Callable[[Iterable[Any]], Any],
457457
map_func: Callable[..., Any],
458-
*args: Any) -> Any:
458+
*args: Any) -> "DeviceArray":
459459
r"""Perform a map-reduce over multiple array containers recursively.
460460
461461
:param reduce_func: callable used to reduce over the components of any

arraycontext/context.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@
7474
---------------------------------
7575
7676
.. currentmodule:: arraycontext
77+
78+
.. class:: DeviceArray
79+
80+
A (type alias for an) array type supported by the :class:`ArrayContext`
81+
meant to aid in typing annotations. For a explicit list of supported types
82+
see :attr:`ArrayContext.array_types`.
83+
84+
.. class:: DeviceScalar
85+
86+
A (type alias for a) scalar type supported by the :class:`ArrayContext`
87+
meant to aid in typing annotations, e.g. for reductions. In :mod:`numpy`
88+
terminology, this is just an array with a shape of ``()``.
89+
7790
.. autoclass:: ArrayContext
7891
"""
7992

@@ -110,6 +123,10 @@
110123
from pytools.tag import Tag
111124

112125

126+
DeviceArray = Any
127+
DeviceScalar = Any
128+
129+
113130
# {{{ ArrayContext
114131

115132
class ArrayContext(ABC):

arraycontext/impl/pytato/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
from typing import Any, Dict, Set, Tuple, Mapping
27-
from pytato.array import SizeParam, Placeholder
27+
from pytato.array import SizeParam, Placeholder, make_placeholder
2828
from pytato.array import Array, DataWrapper, DictOfNamedArrays
2929
from pytato.transform import CopyMapper
3030
from pytools import UniqueNameGenerator
@@ -52,11 +52,12 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
5252
# Normalizing names so that we more arrays can have the normalized DAG.
5353
name = self.vng("_actx_dw")
5454
self.bound_arguments[name] = expr.data
55-
return Placeholder(name=name,
56-
shape=tuple(self.rec(s) if isinstance(s, Array) else s
57-
for s in expr.shape),
58-
dtype=expr.dtype,
59-
tags=expr.tags)
55+
return make_placeholder(
56+
name=name,
57+
shape=tuple(self.rec(s) if isinstance(s, Array) else s
58+
for s in expr.shape),
59+
dtype=expr.dtype,
60+
tags=expr.tags)
6061

6162
def map_size_param(self, expr: SizeParam) -> Array:
6263
raise NotImplementedError

doc/conf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
version = ".".join(str(x) for x in ver_dic["VERSION"])
1515
release = ver_dic["VERSION_TEXT"]
1616

17+
autodoc_type_aliases = {
18+
"DeviceScalar": "arraycontext.DeviceScalar",
19+
"DeviceArray": "arraycontext.DeviceArray",
20+
}
21+
1722
intersphinx_mapping = {
1823
"https://docs.python.org/3/": None,
1924
"https://numpy.org/doc/stable/": None,

0 commit comments

Comments
 (0)