Skip to content

Commit f37f298

Browse files
committed
Introduce Bcast object-ified broacasting rules
1 parent d8e8683 commit f37f298

File tree

3 files changed

+319
-21
lines changed

3 files changed

+319
-21
lines changed

arraycontext/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
serialize_container,
4747
)
4848
from .container.arithmetic import (
49+
Bcast,
50+
Bcast1Level,
51+
Bcast2Levels,
52+
Bcast3Levels,
53+
BcastNLevels,
54+
BcastUntilActxArray,
4955
with_container_arithmetic,
5056
)
5157
from .container.dataclass import dataclass_array_container
@@ -115,6 +121,12 @@
115121
"ArrayOrContainerOrScalarT",
116122
"ArrayOrContainerT",
117123
"ArrayT",
124+
"Bcast",
125+
"Bcast1Level",
126+
"Bcast2Levels",
127+
"Bcast3Levels",
128+
"BcastNLevels",
129+
"BcastUntilActxArray",
118130
"CommonSubexpressionTag",
119131
"EagerJAXArrayContext",
120132
"ElementwiseMapKernelTag",

arraycontext/container/arithmetic.py

+222-19
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@
66
.. currentmodule:: arraycontext
77
88
.. autofunction:: with_container_arithmetic
9+
.. autoclass:: Bcast
10+
.. autoclass:: BcastNLevels
11+
.. autoclass:: BcastUntilActxArray
12+
13+
.. function:: Bcast1
14+
15+
Like :class:`BcastNLevels` with *nlevels* set to 1.
16+
17+
.. function:: Bcast2
18+
19+
Like :class:`BcastNLevels` with *nlevels* set to 2.
20+
21+
.. function:: Bcast3
22+
23+
Like :class:`BcastNLevels` with *nlevels* set to 3.
924
"""
1025

1126

@@ -34,12 +49,18 @@
3449
"""
3550

3651
import enum
52+
from abc import ABC, abstractmethod
3753
from collections.abc import Callable
38-
from typing import Any, TypeVar
54+
from dataclasses import FrozenInstanceError
55+
from functools import partial
56+
from numbers import Number
57+
from typing import Any, ClassVar, TypeVar, Union
3958
from warnings import warn
4059

4160
import numpy as np
4261

62+
from arraycontext.context import ArrayContext, ArrayOrContainer
63+
4364

4465
# {{{ with_container_arithmetic
4566

@@ -142,8 +163,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142163
warn(
143164
"Broadcasting container against non-object numpy array. "
144165
"This was never documented to work and will now stop working in "
145-
"2025. Convert the array to an object array to preserve the "
146-
"current semantics.", DeprecationWarning, stacklevel=3)
166+
"2025. Convert the array to an object array or use "
167+
"variants of arraycontext.Bcast to obtain the desired "
168+
"broadcasting semantics.", DeprecationWarning, stacklevel=3)
147169
return True
148170
else:
149171
return False
@@ -153,6 +175,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
153175
pass
154176

155177

178+
class Bcast:
179+
"""
180+
A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
181+
to broadcast *arg* across a container (with the container as the 'outer' structure).
182+
Since array containers are often nested in complex ways, different subclasses
183+
implement different rules on how broadcasting interacts with the hierarchy,
184+
with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
185+
the most common.
186+
"""
187+
arg: ArrayOrContainer
188+
189+
# Accessing this attribute is cheaper than isinstance, so use that
190+
# to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
191+
_with_next_operand: ClassVar[bool]
192+
193+
def __init__(self, arg: ArrayOrContainer) -> None:
194+
object.__setattr__(self, "arg", arg)
195+
196+
def __setattr__(self, name: str, value: Any) -> None:
197+
raise FrozenInstanceError()
198+
199+
def __delattr__(self, name: str) -> None:
200+
raise FrozenInstanceError()
201+
202+
203+
class _BcastWithNextOperand(Bcast, ABC):
204+
"""
205+
A :class:`Bcast` object that gets to see who the next operand will be, in
206+
order to decide whether wrapping the child in :class:`Bcast` is still necessary.
207+
This is much more flexible, but also considerably more expensive, than
208+
:class:`_BcastWithoutNextOperand`.
209+
"""
210+
211+
_with_next_operand = True
212+
213+
# purposefully undocumented
214+
@abstractmethod
215+
def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer:
216+
...
217+
218+
219+
class _BcastWithoutNextOperand(Bcast, ABC):
220+
"""
221+
A :class:`Bcast` object that does not get to see who the next operand will be.
222+
"""
223+
_with_next_operand = False
224+
225+
# purposefully undocumented
226+
@abstractmethod
227+
def _rewrap(self) -> ArrayOrContainer:
228+
...
229+
230+
231+
class BcastNLevels(_BcastWithoutNextOperand):
232+
"""
233+
A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
234+
array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
235+
convenient aliases for the common cases.
236+
237+
Usage example::
238+
239+
container + Bcast2(actx_array)
240+
241+
.. note::
242+
243+
:mod:`numpy` object arrays do not count against the number of levels.
244+
245+
.. automethod:: __init__
246+
"""
247+
nlevels: int
248+
249+
def __init__(self, nlevels: int, arg: ArrayOrContainer) -> None:
250+
if nlevels < 1:
251+
raise ValueError("nlevels is expected to be one or greater.")
252+
253+
super().__init__(arg)
254+
object.__setattr__(self, "nlevels", nlevels)
255+
256+
def _rewrap(self) -> ArrayOrContainer:
257+
if self.nlevels == 1:
258+
return self.arg
259+
else:
260+
return BcastNLevels(self.nlevels-1, self.arg)
261+
262+
263+
Bcast1Level = partial(BcastNLevels, 1)
264+
Bcast2Levels = partial(BcastNLevels, 2)
265+
Bcast3Levels = partial(BcastNLevels, 3)
266+
267+
268+
class BcastUntilActxArray(_BcastWithNextOperand):
269+
"""
270+
A broadcast rule that broadcasts *arg* across array containers until
271+
the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
272+
of *actx*, or a :class:`~numbers.Number`.
273+
274+
Suggested usage pattern::
275+
276+
bcast = functools.partial(BcastUntilActxArray, actx)
277+
278+
container + bcast(actx_array)
279+
280+
.. automethod:: __init__
281+
"""
282+
actx: ArrayContext
283+
284+
def __init__(self,
285+
actx: ArrayContext,
286+
arg: ArrayOrContainer) -> None:
287+
super().__init__(arg)
288+
object.__setattr__(self, "actx", actx)
289+
290+
def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer:
291+
if isinstance(other_operand, (*self.actx.array_types, Number)):
292+
return self.arg
293+
else:
294+
return self
295+
296+
156297
def with_container_arithmetic(
157298
*,
158299
number_bcasts_across: bool | None = None,
@@ -207,6 +348,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207348
208349
Each operator class also includes the "reverse" operators if applicable.
209350
351+
.. note::
352+
353+
For the generated binary arithmetic operators, if certain types
354+
should be broadcast over the container (with the container as the
355+
'outer' structure) but are not handled in this way by their types,
356+
you may wrap them in one of the :class:`Bcast` variants to achieve
357+
the desired semantics.
358+
210359
.. note::
211360
212361
To generate the code implementing the operators, this function relies on
@@ -239,6 +388,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
239388
#
240389
# - Broadcast rules are hard to change once established, particularly
241390
# because one cannot grep for their use.
391+
#
392+
# Possible advantages of the "Bcast" broadcast-rule-as-object design:
393+
#
394+
# - If one rule does not fit the user's need, they can straightforwardly use
395+
# another.
396+
#
397+
# - It's straightforward to find where certain broadcast rules are used.
398+
#
399+
# - The broadcast rule can contain more state. For example, it's now easy
400+
# for the rule to know what array context should be used to determine
401+
# actx array types.
402+
#
403+
# Possible downsides of the "Bcast" broadcast-rule-as-object design:
404+
#
405+
# - User code is a bit more wordy.
406+
#
407+
# - Rewrapping has the potential to be costly, especially in
408+
# _with_next_operand mode.
242409

243410
# {{{ handle inputs
244411

@@ -404,9 +571,8 @@ def wrap(cls: Any) -> Any:
404571
f"Broadcasting array context array types across {cls} "
405572
"has been explicitly "
406573
"enabled. As of 2025, this will stop working. "
407-
"There is no replacement as of right now. "
408-
"See the discussion in "
409-
"https://github.com/inducer/arraycontext/pull/190. "
574+
"Express these operations using arraycontext.Bcast variants "
575+
"instead. "
410576
"To opt out now (and avoid this warning), "
411577
"pass _bcast_actx_array_type=False. ",
412578
DeprecationWarning, stacklevel=2)
@@ -415,9 +581,8 @@ def wrap(cls: Any) -> Any:
415581
f"Broadcasting array context array types across {cls} "
416582
"has been implicitly "
417583
"enabled. As of 2025, this will no longer work. "
418-
"There is no replacement as of right now. "
419-
"See the discussion in "
420-
"https://github.com/inducer/arraycontext/pull/190. "
584+
"Express these operations using arraycontext.Bcast variants "
585+
"instead. "
421586
"To opt out now (and avoid this warning), "
422587
"pass _bcast_actx_array_type=False.",
423588
DeprecationWarning, stacklevel=2)
@@ -435,7 +600,7 @@ def wrap(cls: Any) -> Any:
435600
gen(f"""
436601
from numbers import Number
437602
import numpy as np
438-
from arraycontext import ArrayContainer
603+
from arraycontext import ArrayContainer, Bcast
439604
from warnings import warn
440605
441606
def _raise_if_actx_none(actx):
@@ -455,7 +620,8 @@ def is_numpy_array(arg):
455620
"behavior will change in 2025. If you would like the "
456621
"broadcasting behavior to stay the same, make sure "
457622
"to convert the passed numpy array to an "
458-
"object array.",
623+
"object array, or use arraycontext.Bcast to achieve "
624+
"the desired broadcasting semantics.",
459625
DeprecationWarning, stacklevel=3)
460626
return True
461627
else:
@@ -553,6 +719,33 @@ def {fname}(arg1):
553719
cls._serialize_init_arrays_code("arg2").items()
554720
})
555721

722+
def get_operand(arg: Union[tuple[str, str], str]) -> str:
723+
if isinstance(arg, tuple):
724+
entry, _container = arg
725+
return entry
726+
else:
727+
return arg
728+
729+
bcast_init_args_arg1_is_outer_with_rewrap = \
730+
cls._deserialize_init_arrays_code("arg1", {
731+
key_arg1:
732+
_format_binary_op_str(
733+
op_str, expr_arg1,
734+
f"arg2._rewrap({get_operand(expr_arg1)})")
735+
for key_arg1, expr_arg1 in
736+
cls._serialize_init_arrays_code("arg1").items()
737+
})
738+
bcast_init_args_arg2_is_outer_with_rewrap = \
739+
cls._deserialize_init_arrays_code("arg2", {
740+
key_arg2:
741+
_format_binary_op_str(
742+
op_str,
743+
f"arg1._rewrap({get_operand(expr_arg2)})",
744+
expr_arg2)
745+
for key_arg2, expr_arg2 in
746+
cls._serialize_init_arrays_code("arg2").items()
747+
})
748+
556749
# {{{ "forward" binary operators
557750

558751
gen(f"def {fname}(arg1, arg2):")
@@ -605,14 +798,19 @@ def {fname}(arg1):
605798
warn("Broadcasting {cls} over array "
606799
f"context array type {{type(arg2)}} is deprecated "
607800
"and will no longer work in 2025. "
608-
"There is no replacement as of right now. "
609-
"See the discussion in "
610-
"https://github.com/inducer/arraycontext/"
611-
"pull/190. ",
801+
"Use arraycontext.Bcast to achieve the desired "
802+
"broadcasting semantics.",
612803
DeprecationWarning, stacklevel=2)
613804
614805
return cls({bcast_init_args_arg1_is_outer})
615806
807+
if isinstance(arg2, Bcast):
808+
if arg2._with_next_operand:
809+
return cls({bcast_init_args_arg1_is_outer_with_rewrap})
810+
else:
811+
arg2 = arg2._rewrap()
812+
return cls({bcast_init_args_arg1_is_outer})
813+
616814
return NotImplemented
617815
""")
618816
gen(f"cls.__{dunder_name}__ = {fname}")
@@ -656,14 +854,19 @@ def {fname}(arg2, arg1):
656854
f"context array type {{type(arg1)}} "
657855
"is deprecated "
658856
"and will no longer work in 2025."
659-
"There is no replacement as of right now. "
660-
"See the discussion in "
661-
"https://github.com/inducer/arraycontext/"
662-
"pull/190. ",
857+
"Use arraycontext.Bcast to achieve the "
858+
"desired broadcasting semantics.",
663859
DeprecationWarning, stacklevel=2)
664860
665861
return cls({bcast_init_args_arg2_is_outer})
666862
863+
if isinstance(arg1, Bcast):
864+
if arg1._with_next_operand:
865+
return cls({bcast_init_args_arg2_is_outer_with_rewrap})
866+
else:
867+
arg1 = arg1._rewrap()
868+
return cls({bcast_init_args_arg2_is_outer})
869+
667870
return NotImplemented
668871
669872
cls.__r{dunder_name}__ = {fname}""")

0 commit comments

Comments
 (0)