Skip to content

Commit 955820f

Browse files
committed
Rework dataclass array container arithmetic
- Deprecate automatic broadcasting of array context arrays - Introduce Bcast as an object to represent broadcast rules - Warn about uses of numpy array broadcasting, deprecated earlier - Clarify documentation, warning wording
1 parent 58acd1f commit 955820f

File tree

4 files changed

+185
-84
lines changed

4 files changed

+185
-84
lines changed

arraycontext/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
register_multivector_as_array_container,
4444
serialize_container,
4545
)
46-
from .container.arithmetic import with_container_arithmetic
46+
from .container.arithmetic import (
47+
with_container_arithmetic,
48+
)
4749
from .container.dataclass import dataclass_array_container
4850
from .container.traversal import (
4951
flat_size_and_dtype,
@@ -151,7 +153,6 @@
151153
"unflatten",
152154
"with_array_context",
153155
"with_container_arithmetic",
154-
"with_container_arithmetic"
155156
)
156157

157158

arraycontext/container/arithmetic.py

+159-43
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
from __future__ import annotations
33

44

5-
"""
5+
__doc__ = """
66
.. currentmodule:: arraycontext
7+
78
.. autofunction:: with_container_arithmetic
89
"""
910

10-
import enum
11-
1211

1312
__copyright__ = """
1413
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -34,7 +33,9 @@
3433
THE SOFTWARE.
3534
"""
3635

36+
import enum
3737
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
38+
from warnings import warn
3839

3940
import numpy as np
4041

@@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
99100

100101

101102
def _format_binary_op_str(op_str: str,
102-
arg1: Union[Tuple[str, ...], str],
103-
arg2: Union[Tuple[str, ...], str]) -> str:
103+
arg1: Union[Tuple[str, str], str],
104+
arg2: Union[Tuple[str, str], str]) -> str:
104105
if isinstance(arg1, tuple) and isinstance(arg2, tuple):
105106
import sys
106107
if sys.version_info >= (3, 10):
@@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str,
127128
return op_str.format(arg1, arg2)
128129

129130

131+
class NumpyObjectArrayMetaclass(type):
132+
def __instancecheck__(cls, instance: Any) -> bool:
133+
return isinstance(instance, np.ndarray) and instance.dtype == object
134+
135+
136+
class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass):
137+
pass
138+
139+
140+
class ComplainingNumpyNonObjectArrayMetaclass(type):
141+
def __instancecheck__(cls, instance: Any) -> bool:
142+
if isinstance(instance, np.ndarray) and instance.dtype != object:
143+
# Example usage site:
144+
# https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149
145+
# where normal is passed in by test_lfr_flux as a 'custom-made'
146+
# numpy array of dtype float64.
147+
warn(
148+
"Broadcasting container against non-object numpy array. "
149+
"This was never documented to work and will now stop working in "
150+
"2025. Convert the array to an object array to preserve the "
151+
"current semantics.", DeprecationWarning, stacklevel=3)
152+
return True
153+
else:
154+
return False
155+
156+
157+
class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMetaclass):
158+
pass
159+
160+
130161
def with_container_arithmetic(
131162
*,
132163
bcast_number: bool = True,
@@ -146,22 +177,16 @@ def with_container_arithmetic(
146177
147178
:arg bcast_number: If *True*, numbers broadcast over the container
148179
(with the container as the 'outer' structure).
149-
:arg _bcast_actx_array_type: If *True*, instances of base array types of the
150-
container's array context are broadcasted over the container. Can be
151-
*True* only if the container has *_cls_has_array_context_attr* set.
152-
Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set,
153-
else *False*.
154-
:arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over
155-
the container. (with the container as the 'inner' structure)
156-
:arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
157-
over the container. (with the container as the 'inner' structure)
158-
If this is set to *True*, *bcast_obj_array* must also be *True*.
180+
:arg bcast_obj_array: If *True*, this container will be broadcast
181+
across :mod:`numpy` object arrays
182+
(with the object array as the 'outer' structure).
183+
Add :class:`numpy.ndarray` to *bcast_container_types* to achieve
184+
the 'reverse' broadcasting.
159185
:arg bcast_container_types: A sequence of container types that will broadcast
160-
over this container (with this container as the 'outer' structure).
186+
across this container, with this container as the 'outer' structure.
161187
:class:`numpy.ndarray` is permitted to be part of this sequence to
162-
indicate that, in such broadcasting situations, this container should
163-
be the 'outer' structure. In this case, *bcast_obj_array*
164-
(and consequently *bcast_numpy_array*) must be *False*.
188+
indicate that object arrays (and *only* object arrays) will be broadcasat.
189+
In this case, *bcast_obj_array* must be *False*.
165190
:arg arithmetic: Implement the conventional arithmetic operators, including
166191
``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
167192
:func:`abs`.
@@ -203,6 +228,17 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
203228
should nest "outside" :func:dataclass_array_container`.
204229
"""
205230

231+
# Hard-won design lessons:
232+
#
233+
# - Anything that special-cases np.ndarray by type is broken by design because:
234+
# - np.ndarray is an array context array.
235+
# - numpy object arrays can be array containers.
236+
# Using NumpyObjectArray and NumpyNonObjectArray *may* be better?
237+
# They're new, so there is no operational experience with them.
238+
#
239+
# - Broadcast rules are hard to change once established, particularly
240+
# because one cannot grep for their use.
241+
206242
# {{{ handle inputs
207243

208244
if bcast_obj_array is None:
@@ -212,9 +248,8 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
212248
raise TypeError("rel_comparison must be specified")
213249

214250
if bcast_numpy_array:
215-
from warnings import warn
216251
warn("'bcast_numpy_array=True' is deprecated and will be unsupported"
217-
" from December 2021", DeprecationWarning, stacklevel=2)
252+
" from 2025.", DeprecationWarning, stacklevel=2)
218253

219254
if _bcast_actx_array_type:
220255
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
@@ -231,7 +266,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
231266

232267
if bcast_numpy_array:
233268
def numpy_pred(name: str) -> str:
234-
return f"isinstance({name}, np.ndarray)"
269+
return f"is_numpy_array({name})"
235270
elif bcast_obj_array:
236271
def numpy_pred(name: str) -> str:
237272
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
@@ -241,12 +276,21 @@ def numpy_pred(name: str) -> str:
241276

242277
if bcast_container_types is None:
243278
bcast_container_types = ()
244-
bcast_container_types_count = len(bcast_container_types)
245279

246280
if np.ndarray in bcast_container_types and bcast_obj_array:
247281
raise ValueError("If numpy.ndarray is part of bcast_container_types, "
248282
"bcast_obj_array must be False.")
249283

284+
numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray]
285+
bcast_container_types = tuple(
286+
new_ct
287+
for old_ct in bcast_container_types
288+
for new_ct in
289+
(numpy_check_types
290+
if old_ct is np.ndarray
291+
else [old_ct])
292+
)
293+
250294
desired_op_classes = set()
251295
if arithmetic:
252296
desired_op_classes.add(_OpClass.ARITHMETIC)
@@ -264,19 +308,24 @@ def numpy_pred(name: str) -> str:
264308
# }}}
265309

266310
def wrap(cls: Any) -> Any:
267-
cls_has_array_context_attr: bool | None = \
268-
_cls_has_array_context_attr
269-
bcast_actx_array_type: bool | None = \
270-
_bcast_actx_array_type
311+
if not hasattr(cls, "__array_ufunc__"):
312+
warn(f"{cls} does not have __array_ufunc__ set. "
313+
"This will cause numpy to attempt broadcasting, in a way that "
314+
"is likely undesired. "
315+
f"To avoid this, set __array_ufunc__ = None in {cls}.",
316+
stacklevel=2)
317+
318+
cls_has_array_context_attr: bool | None = _cls_has_array_context_attr
319+
bcast_actx_array_type: bool | None = _bcast_actx_array_type
271320

272321
if cls_has_array_context_attr is None:
273322
if hasattr(cls, "array_context"):
274323
raise TypeError(
275324
f"{cls} has an 'array_context' attribute, but it does not "
276325
"set '_cls_has_array_context_attr' to *True* when calling "
277326
"with_container_arithmetic. This is being interpreted "
278-
"as 'array_context' being permitted to fail with an exception, "
279-
"which is no longer allowed. "
327+
"as '.array_context' being permitted to fail "
328+
"with an exception, which is no longer allowed. "
280329
f"If {cls.__name__}.array_context will not fail, pass "
281330
"'_cls_has_array_context_attr=True'. "
282331
"If you do not want container arithmetic to make "
@@ -294,6 +343,30 @@ def wrap(cls: Any) -> Any:
294343
raise TypeError("_bcast_actx_array_type can be True only if "
295344
"_cls_has_array_context_attr is set.")
296345

346+
if bcast_actx_array_type:
347+
if _bcast_actx_array_type:
348+
warn(
349+
f"Broadcasting array context array types across {cls} "
350+
"has been explicitly "
351+
"enabled. As of 2025, this will stop working. "
352+
"There is no replacement as of right now. "
353+
"See the discussion in "
354+
"https://github.com/inducer/arraycontext/pull/190. "
355+
"To opt out now (and avoid this warning), "
356+
"pass _bcast_actx_array_type=False. ",
357+
DeprecationWarning, stacklevel=2)
358+
else:
359+
warn(
360+
f"Broadcasting array context array types across {cls} "
361+
"has been implicitly "
362+
"enabled. As of 2025, this will no longer work. "
363+
"There is no replacement as of right now. "
364+
"See the discussion in "
365+
"https://github.com/inducer/arraycontext/pull/190. "
366+
"To opt out now (and avoid this warning), "
367+
"pass _bcast_actx_array_type=False.",
368+
DeprecationWarning, stacklevel=2)
369+
297370
if (not hasattr(cls, "_serialize_init_arrays_code")
298371
or not hasattr(cls, "_deserialize_init_arrays_code")):
299372
raise TypeError(f"class '{cls.__name__}' must provide serialization "
@@ -304,7 +377,7 @@ def wrap(cls: Any) -> Any:
304377

305378
from pytools.codegen import CodeGenerator, Indentation
306379
gen = CodeGenerator()
307-
gen("""
380+
gen(f"""
308381
from numbers import Number
309382
import numpy as np
310383
from arraycontext import ArrayContainer
@@ -315,6 +388,24 @@ def _raise_if_actx_none(actx):
315388
raise ValueError("array containers with frozen arrays "
316389
"cannot be operated upon")
317390
return actx
391+
392+
def is_numpy_array(arg):
393+
if isinstance(arg, np.ndarray):
394+
if arg.dtype != "O":
395+
warn("Operand is a non-object numpy array, "
396+
"and the broadcasting behavior of this array container "
397+
"({cls}) "
398+
"is influenced by this because of its use of "
399+
"the deprecated bcast_numpy_array. This broadcasting "
400+
"behavior will change in 2025. If you would like the "
401+
"broadcasting behavior to stay the same, make sure "
402+
"to convert the passed numpy array to an "
403+
"object array.",
404+
DeprecationWarning, stacklevel=3)
405+
return True
406+
else:
407+
return False
408+
318409
""")
319410
gen("")
320411

@@ -323,7 +414,7 @@ def _raise_if_actx_none(actx):
323414
gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
324415
gen("")
325416
outer_bcast_type_names = tuple([
326-
f"_bctype{i}" for i in range(bcast_container_types_count)
417+
f"_bctype{i}" for i in range(len(bcast_container_types))
327418
])
328419
if bcast_number:
329420
outer_bcast_type_names += ("Number",)
@@ -384,20 +475,25 @@ def {fname}(arg1):
384475

385476
continue
386477

387-
# {{{ "forward" binary operators
388-
389478
zip_init_args = cls._deserialize_init_arrays_code("arg1", {
390479
same_key(key_arg1, key_arg2):
391480
_format_binary_op_str(op_str, expr_arg1, expr_arg2)
392481
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
393482
cls._serialize_init_arrays_code("arg1").items(),
394483
cls._serialize_init_arrays_code("arg2").items())
395484
})
396-
bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", {
485+
bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
397486
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
398487
for key_arg1, expr_arg1 in
399488
cls._serialize_init_arrays_code("arg1").items()
400489
})
490+
bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", {
491+
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
492+
for key_arg2, expr_arg2 in
493+
cls._serialize_init_arrays_code("arg2").items()
494+
})
495+
496+
# {{{ "forward" binary operators
401497

402498
gen(f"def {fname}(arg1, arg2):")
403499
with Indentation(gen):
@@ -424,7 +520,7 @@ def {fname}(arg1):
424520

425521
if bcast_actx_array_type:
426522
if __debug__:
427-
bcast_actx_ary_types = (
523+
bcast_actx_ary_types: tuple[str, ...] = (
428524
"*_raise_if_actx_none("
429525
"arg1.array_context).array_types",)
430526
else:
@@ -444,7 +540,19 @@ def {fname}(arg1):
444540
if isinstance(arg2,
445541
{tup_str(outer_bcast_type_names
446542
+ bcast_actx_ary_types)}):
447-
return cls({bcast_same_cls_init_args})
543+
if __debug__:
544+
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
545+
warn("Broadcasting {cls} over array "
546+
f"context array type {{type(arg2)}} is deprecated "
547+
"and will no longer work in 2025. "
548+
"There is no replacement as of right now. "
549+
"See the discussion in "
550+
"https://github.com/inducer/arraycontext/"
551+
"pull/190. ",
552+
DeprecationWarning, stacklevel=2)
553+
554+
return cls({bcast_init_args_arg1_is_outer})
555+
448556
return NotImplemented
449557
""")
450558
gen(f"cls.__{dunder_name}__ = {fname}")
@@ -456,12 +564,6 @@ def {fname}(arg1):
456564

457565
if reversible:
458566
fname = f"_{cls.__name__.lower()}_r{dunder_name}"
459-
bcast_init_args = cls._deserialize_init_arrays_code("arg2", {
460-
key_arg2: _format_binary_op_str(
461-
op_str, "arg1", expr_arg2)
462-
for key_arg2, expr_arg2 in
463-
cls._serialize_init_arrays_code("arg2").items()
464-
})
465567

466568
if bcast_actx_array_type:
467569
if __debug__:
@@ -487,7 +589,21 @@ def {fname}(arg2, arg1):
487589
if isinstance(arg1,
488590
{tup_str(outer_bcast_type_names
489591
+ bcast_actx_ary_types)}):
490-
return cls({bcast_init_args})
592+
if __debug__:
593+
if isinstance(arg1,
594+
{tup_str(bcast_actx_ary_types)}):
595+
warn("Broadcasting {cls} over array "
596+
f"context array type {{type(arg1)}} "
597+
"is deprecated "
598+
"and will no longer work in 2025."
599+
"There is no replacement as of right now. "
600+
"See the discussion in "
601+
"https://github.com/inducer/arraycontext/"
602+
"pull/190. ",
603+
DeprecationWarning, stacklevel=2)
604+
605+
return cls({bcast_init_args_arg2_is_outer})
606+
491607
return NotImplemented
492608
493609
cls.__r{dunder_name}__ = {fname}""")

0 commit comments

Comments
 (0)