Skip to content

Commit d9c25fc

Browse files
committed
Introduce Bcast object-ified broacasting rules
1 parent 6251b61 commit d9c25fc

File tree

3 files changed

+321
-23
lines changed

3 files changed

+321
-23
lines changed

arraycontext/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@
4444
serialize_container,
4545
)
4646
from .container.arithmetic import (
47+
Bcast,
48+
Bcast1Level,
49+
Bcast2Levels,
50+
Bcast3Levels,
51+
BcastNLevels,
52+
BcastUntilActxArray,
4753
with_container_arithmetic,
4854
)
4955
from .container.dataclass import dataclass_array_container
@@ -105,6 +111,12 @@
105111
"ArrayOrContainerOrScalarT",
106112
"ArrayOrContainerT",
107113
"ArrayT",
114+
"Bcast",
115+
"Bcast1Level",
116+
"Bcast2Levels",
117+
"Bcast3Levels",
118+
"BcastNLevels",
119+
"BcastUntilActxArray",
108120
"CommonSubexpressionTag",
109121
"EagerJAXArrayContext",
110122
"ElementwiseMapKernelTag",

arraycontext/container/arithmetic.py

+222-19
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@
66
.. currentmodule:: arraycontext
77
88
.. autofunction:: with_container_arithmetic
9+
.. autoclass:: Bcast
10+
.. autoclass:: BcastNLevels
11+
.. autoclass:: BcastUntilActxArray
12+
13+
.. function:: Bcast1
14+
15+
Like :class:`BcastNLevels` with *nlevels* set to 1.
16+
17+
.. function:: Bcast2
18+
19+
Like :class:`BcastNLevels` with *nlevels* set to 2.
20+
21+
.. function:: Bcast3
22+
23+
Like :class:`BcastNLevels` with *nlevels* set to 3.
924
"""
1025

1126

@@ -34,11 +49,17 @@
3449
"""
3550

3651
import enum
37-
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
52+
from abc import ABC, abstractmethod
53+
from dataclasses import FrozenInstanceError
54+
from functools import partial
55+
from numbers import Number
56+
from typing import Any, Callable, ClassVar, Optional, Tuple, TypeVar, Union
3857
from warnings import warn
3958

4059
import numpy as np
4160

61+
from arraycontext.context import ArrayContext, ArrayOrContainer
62+
4263

4364
# {{{ with_container_arithmetic
4465

@@ -147,8 +168,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
147168
warn(
148169
"Broadcasting container against non-object numpy array. "
149170
"This was never documented to work and will now stop working in "
150-
"2025. Convert the array to an object array to preserve the "
151-
"current semantics.", DeprecationWarning, stacklevel=3)
171+
"2025. Convert the array to an object array or use "
172+
"variants of arraycontext.Bcast to obtain the desired "
173+
"broadcasting semantics.", DeprecationWarning, stacklevel=3)
152174
return True
153175
else:
154176
return False
@@ -158,6 +180,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
158180
pass
159181

160182

183+
class Bcast:
184+
"""
185+
A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
186+
to broadcast *arg* across a container (with the container as the 'outer' structure).
187+
Since array containers are often nested in complex ways, different subclasses
188+
implement different rules on how broadcasting interacts with the hierarchy,
189+
with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
190+
the most common.
191+
"""
192+
arg: ArrayOrContainer
193+
194+
# Accessing this attribute is cheaper than isinstance, so use that
195+
# to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
196+
_with_next_operand: ClassVar[bool]
197+
198+
def __init__(self, arg: ArrayOrContainer) -> None:
199+
object.__setattr__(self, "arg", arg)
200+
201+
def __setattr__(self, name: str, value: Any) -> None:
202+
raise FrozenInstanceError()
203+
204+
def __delattr__(self, name: str) -> None:
205+
raise FrozenInstanceError()
206+
207+
208+
class _BcastWithNextOperand(Bcast, ABC):
209+
"""
210+
A :class:`Bcast` object that gets to see who the next operand will be, in
211+
order to decide whether wrapping the child in :class:`Bcast` is still necessary.
212+
This is much more flexible, but also considerably more expensive, than
213+
:class:`_BcastWithoutNextOperand`.
214+
"""
215+
216+
_with_next_operand = True
217+
218+
# purposefully undocumented
219+
@abstractmethod
220+
def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer:
221+
...
222+
223+
224+
class _BcastWithoutNextOperand(Bcast, ABC):
225+
"""
226+
A :class:`Bcast` object that does not get to see who the next operand will be.
227+
"""
228+
_with_next_operand = False
229+
230+
# purposefully undocumented
231+
@abstractmethod
232+
def _rewrap(self) -> ArrayOrContainer:
233+
...
234+
235+
236+
class BcastNLevels(_BcastWithoutNextOperand):
237+
"""
238+
A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
239+
array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
240+
convenient aliases for the common cases.
241+
242+
Usage example::
243+
244+
container + Bcast2(actx_array)
245+
246+
.. note::
247+
248+
:mod:`numpy` object arrays do not count against the number of levels.
249+
250+
.. automethod:: __init__
251+
"""
252+
nlevels: int
253+
254+
def __init__(self, nlevels: int, arg: ArrayOrContainer) -> None:
255+
if nlevels < 1:
256+
raise ValueError("nlevels is expected to be one or greater.")
257+
258+
super().__init__(arg)
259+
object.__setattr__(self, "nlevels", nlevels)
260+
261+
def _rewrap(self) -> ArrayOrContainer:
262+
if self.nlevels == 1:
263+
return self.arg
264+
else:
265+
return BcastNLevels(self.nlevels-1, self.arg)
266+
267+
268+
Bcast1Level = partial(BcastNLevels, 1)
269+
Bcast2Levels = partial(BcastNLevels, 2)
270+
Bcast3Levels = partial(BcastNLevels, 3)
271+
272+
273+
class BcastUntilActxArray(_BcastWithNextOperand):
274+
"""
275+
A broadcast rule that broadcasts *arg* across array containers until
276+
the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
277+
of *actx*, or a :class:`~numbers.Number`.
278+
279+
Suggested usage pattern::
280+
281+
bcast = functools.partial(BcastUntilActxArray, actx)
282+
283+
container + bcast(actx_array)
284+
285+
.. automethod:: __init__
286+
"""
287+
actx: ArrayContext
288+
289+
def __init__(self,
290+
actx: ArrayContext,
291+
arg: ArrayOrContainer) -> None:
292+
super().__init__(arg)
293+
object.__setattr__(self, "actx", actx)
294+
295+
def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer:
296+
if isinstance(other_operand, (*self.actx.array_types, Number)):
297+
return self.arg
298+
else:
299+
return self
300+
301+
161302
def with_container_arithmetic(
162303
*,
163304
bcast_number: bool = True,
@@ -206,6 +347,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
206347
207348
Each operator class also includes the "reverse" operators if applicable.
208349
350+
.. note::
351+
352+
For the generated binary arithmetic operators, if certain types
353+
should be broadcast over the container (with the container as the
354+
'outer' structure) but are not handled in this way by their types,
355+
you may wrap them in one of the :class:`Bcast` variants to achieve
356+
the desired semantics.
357+
209358
.. note::
210359
211360
To generate the code implementing the operators, this function relies on
@@ -238,6 +387,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
238387
#
239388
# - Broadcast rules are hard to change once established, particularly
240389
# because one cannot grep for their use.
390+
#
391+
# Possible advantages of the "Bcast" broadcast-rule-as-object design:
392+
#
393+
# - If one rule does not fit the user's need, they can straightforwardly use
394+
# another.
395+
#
396+
# - It's straightforward to find where certain broadcast rules are used.
397+
#
398+
# - The broadcast rule can contain more state. For example, it's now easy
399+
# for the rule to know what array context should be used to determine
400+
# actx array types.
401+
#
402+
# Possible downsides of the "Bcast" broadcast-rule-as-object design:
403+
#
404+
# - User code is a bit more wordy.
405+
#
406+
# - Rewrapping has the potential to be costly, especially in
407+
# _with_next_operand mode.
241408

242409
# {{{ handle inputs
243410

@@ -349,9 +516,8 @@ def wrap(cls: Any) -> Any:
349516
f"Broadcasting array context array types across {cls} "
350517
"has been explicitly "
351518
"enabled. As of 2025, this will stop working. "
352-
"There is no replacement as of right now. "
353-
"See the discussion in "
354-
"https://github.com/inducer/arraycontext/pull/190. "
519+
"Express these operations using arraycontext.Bcast variants "
520+
"instead. "
355521
"To opt out now (and avoid this warning), "
356522
"pass _bcast_actx_array_type=False. ",
357523
DeprecationWarning, stacklevel=2)
@@ -360,9 +526,8 @@ def wrap(cls: Any) -> Any:
360526
f"Broadcasting array context array types across {cls} "
361527
"has been implicitly "
362528
"enabled. As of 2025, this will no longer work. "
363-
"There is no replacement as of right now. "
364-
"See the discussion in "
365-
"https://github.com/inducer/arraycontext/pull/190. "
529+
"Express these operations using arraycontext.Bcast variants "
530+
"instead. "
366531
"To opt out now (and avoid this warning), "
367532
"pass _bcast_actx_array_type=False.",
368533
DeprecationWarning, stacklevel=2)
@@ -380,7 +545,7 @@ def wrap(cls: Any) -> Any:
380545
gen(f"""
381546
from numbers import Number
382547
import numpy as np
383-
from arraycontext import ArrayContainer
548+
from arraycontext import ArrayContainer, Bcast
384549
from warnings import warn
385550
386551
def _raise_if_actx_none(actx):
@@ -400,7 +565,8 @@ def is_numpy_array(arg):
400565
"behavior will change in 2025. If you would like the "
401566
"broadcasting behavior to stay the same, make sure "
402567
"to convert the passed numpy array to an "
403-
"object array.",
568+
"object array, or use arraycontext.Bcast to achieve "
569+
"the desired broadcasting semantics.",
404570
DeprecationWarning, stacklevel=3)
405571
return True
406572
else:
@@ -492,6 +658,33 @@ def {fname}(arg1):
492658
cls._serialize_init_arrays_code("arg2").items()
493659
})
494660

661+
def get_operand(arg: Union[tuple[str, str], str]) -> str:
662+
if isinstance(arg, tuple):
663+
entry, _container = arg
664+
return entry
665+
else:
666+
return arg
667+
668+
bcast_init_args_arg1_is_outer_with_rewrap = \
669+
cls._deserialize_init_arrays_code("arg1", {
670+
key_arg1:
671+
_format_binary_op_str(
672+
op_str, expr_arg1,
673+
f"arg2._rewrap({get_operand(expr_arg1)})")
674+
for key_arg1, expr_arg1 in
675+
cls._serialize_init_arrays_code("arg1").items()
676+
})
677+
bcast_init_args_arg2_is_outer_with_rewrap = \
678+
cls._deserialize_init_arrays_code("arg2", {
679+
key_arg2:
680+
_format_binary_op_str(
681+
op_str,
682+
f"arg1._rewrap({get_operand(expr_arg2)})",
683+
expr_arg2)
684+
for key_arg2, expr_arg2 in
685+
cls._serialize_init_arrays_code("arg2").items()
686+
})
687+
495688
# {{{ "forward" binary operators
496689

497690
gen(f"def {fname}(arg1, arg2):")
@@ -544,14 +737,19 @@ def {fname}(arg1):
544737
warn("Broadcasting {cls} over array "
545738
f"context array type {{type(arg2)}} is deprecated "
546739
"and will no longer work in 2025. "
547-
"There is no replacement as of right now. "
548-
"See the discussion in "
549-
"https://github.com/inducer/arraycontext/"
550-
"pull/190. ",
740+
"Use arraycontext.Bcast to achieve the desired "
741+
"broadcasting semantics.",
551742
DeprecationWarning, stacklevel=2)
552743
553744
return cls({bcast_init_args_arg1_is_outer})
554745
746+
if isinstance(arg2, Bcast):
747+
if arg2._with_next_operand:
748+
return cls({bcast_init_args_arg1_is_outer_with_rewrap})
749+
else:
750+
arg2 = arg2._rewrap()
751+
return cls({bcast_init_args_arg1_is_outer})
752+
555753
return NotImplemented
556754
""")
557755
gen(f"cls.__{dunder_name}__ = {fname}")
@@ -595,14 +793,19 @@ def {fname}(arg2, arg1):
595793
f"context array type {{type(arg1)}} "
596794
"is deprecated "
597795
"and will no longer work in 2025."
598-
"There is no replacement as of right now. "
599-
"See the discussion in "
600-
"https://github.com/inducer/arraycontext/"
601-
"pull/190. ",
796+
"Use arraycontext.Bcast to achieve the "
797+
"desired broadcasting semantics.",
602798
DeprecationWarning, stacklevel=2)
603799
604800
return cls({bcast_init_args_arg2_is_outer})
605801
802+
if isinstance(arg1, Bcast):
803+
if arg1._with_next_operand:
804+
return cls({bcast_init_args_arg2_is_outer_with_rewrap})
805+
else:
806+
arg1 = arg1._rewrap()
807+
return cls({bcast_init_args_arg2_is_outer})
808+
606809
return NotImplemented
607810
608811
cls.__r{dunder_name}__ = {fname}""")

0 commit comments

Comments
 (0)