Skip to content

Commit fa488f1

Browse files
inducermajosm
andcommitted
Add BcastUntilActxArray
Co-authored-by: Matt Smith <[email protected]>
1 parent 602a63f commit fa488f1

File tree

4 files changed

+232
-12
lines changed

4 files changed

+232
-12
lines changed

arraycontext/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
serialize_container,
4747
)
4848
from .container.arithmetic import (
49+
BcastUntilActxArray,
4950
with_container_arithmetic,
5051
)
5152
from .container.dataclass import dataclass_array_container
@@ -115,6 +116,7 @@
115116
"ArrayOrContainerOrScalarT",
116117
"ArrayOrContainerT",
117118
"ArrayT",
119+
"BcastUntilActxArray",
118120
"CommonSubexpressionTag",
119121
"EagerJAXArrayContext",
120122
"ElementwiseMapKernelTag",

arraycontext/container/arithmetic.py

+144-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
.. currentmodule:: arraycontext
77
88
.. autofunction:: with_container_arithmetic
9+
10+
.. autoclass:: BcastUntilActxArray
911
"""
1012

1113

@@ -34,12 +36,23 @@
3436
"""
3537

3638
import enum
39+
import operator
3740
from collections.abc import Callable
41+
from dataclasses import dataclass, field
42+
from functools import partialmethod
43+
from numbers import Number
3844
from typing import Any, TypeVar
3945
from warnings import warn
4046

4147
import numpy as np
4248

49+
from arraycontext.container import (
50+
NotAnArrayContainerError,
51+
deserialize_container,
52+
serialize_container,
53+
)
54+
from arraycontext.context import ArrayContext, ArrayOrContainer
55+
4356

4457
# {{{ with_container_arithmetic
4558

@@ -142,8 +155,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142155
warn(
143156
"Broadcasting container against non-object numpy array. "
144157
"This was never documented to work and will now stop working in "
145-
"2025. Convert the array to an object array to preserve the "
146-
"current semantics.", DeprecationWarning, stacklevel=3)
158+
"2025. Convert the array to an object array or use "
159+
"arraycontext.BcastUntilActxArray (or similar) to obtain the desired "
160+
"broadcasting semantics.", DeprecationWarning, stacklevel=3)
147161
return True
148162
else:
149163
return False
@@ -207,6 +221,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207221
208222
Each operator class also includes the "reverse" operators if applicable.
209223
224+
.. note::
225+
226+
For the generated binary arithmetic operators, if certain types
227+
should be broadcast over the container (with the container as the
228+
'outer' structure) but are not handled in this way by their types,
229+
you may wrap them in :class:`BcastUntilActxArray` to achieve
230+
the desired semantics.
231+
210232
.. note::
211233
212234
To generate the code implementing the operators, this function relies on
@@ -402,8 +424,9 @@ def wrap(cls: Any) -> Any:
402424
warn(
403425
f"Broadcasting array context array types across {cls} "
404426
"has been explicitly "
405-
"enabled. As of 2025, this will stop working. "
406-
"There is no replacement as of right now. "
427+
"enabled. As of 2026, this will stop working. "
428+
"Use arraycontext.Bcast* object wrappers for "
429+
"roughly equivalent functionality. "
407430
"See the discussion in "
408431
"https://github.com/inducer/arraycontext/pull/190. "
409432
"To opt out now (and avoid this warning), "
@@ -413,8 +436,9 @@ def wrap(cls: Any) -> Any:
413436
warn(
414437
f"Broadcasting array context array types across {cls} "
415438
"has been implicitly "
416-
"enabled. As of 2025, this will no longer work. "
417-
"There is no replacement as of right now. "
439+
"enabled. As of 2026, this will no longer work. "
440+
"Use arraycontext.Bcast* object wrappers for "
441+
"roughly equivalent functionality. "
418442
"See the discussion in "
419443
"https://github.com/inducer/arraycontext/pull/190. "
420444
"To opt out now (and avoid this warning), "
@@ -603,8 +627,9 @@ def {fname}(arg1):
603627
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
604628
warn("Broadcasting {cls} over array "
605629
f"context array type {{type(arg2)}} is deprecated "
606-
"and will no longer work in 2025. "
607-
"There is no replacement as of right now. "
630+
"and will no longer work in 2026. "
631+
"Use arraycontext.Bcast* object wrappers for "
632+
"roughly equivalent functionality. "
608633
"See the discussion in "
609634
"https://github.com/inducer/arraycontext/"
610635
"pull/190. ",
@@ -654,8 +679,10 @@ def {fname}(arg2, arg1):
654679
warn("Broadcasting {cls} over array "
655680
f"context array type {{type(arg1)}} "
656681
"is deprecated "
657-
"and will no longer work in 2025."
658-
"There is no replacement as of right now. "
682+
"and will no longer work in 2026."
683+
"Use arraycontext.Bcast* object "
684+
"wrappers for roughly equivalent "
685+
"functionality. "
659686
"See the discussion in "
660687
"https://github.com/inducer/arraycontext/"
661688
"pull/190. ",
@@ -687,4 +714,111 @@ def {fname}(arg2, arg1):
687714
# }}}
688715

689716

717+
# {{{ Bcast object-ified broadcast rules
718+
719+
# Possible advantages of the "Bcast" broadcast-rule-as-object design:
720+
#
721+
# - If one rule does not fit the user's need, they can straightforwardly use
722+
# another.
723+
#
724+
# - It's straightforward to find where certain broadcast rules are used.
725+
#
726+
# - The broadcast rule can contain more state. For example, it's now easy
727+
# for the rule to know what array context should be used to determine
728+
# actx array types.
729+
#
730+
# Possible downsides of the "Bcast" broadcast-rule-as-object design:
731+
#
732+
# - User code is a bit more wordy.
733+
734+
@dataclass(frozen=True)
735+
class BcastUntilActxArray:
736+
"""
737+
An operator-overloading wrapper around an object (*broadcastee*) that should be
738+
broadcast across array containers until the 'opposite' operand is one of the
739+
:attr:`~arraycontext.ArrayContext.array_types`
740+
of *actx* or a :class:`~numbers.Number`.
741+
742+
Suggested usage pattern::
743+
744+
bcast = functools.partial(BcastUntilActxArray, actx)
745+
746+
container + bcast(actx_array)
747+
748+
.. automethod:: __init__
749+
"""
750+
751+
array_context: ArrayContext
752+
broadcastee: ArrayOrContainer
753+
754+
_stop_types: tuple[type, ...] = field(init=False)
755+
756+
def __post_init__(self) -> None:
757+
object.__setattr__(
758+
self, "_stop_types", (*self.array_context.array_types, Number))
759+
760+
def _binary_op(self,
761+
op: Callable[
762+
[ArrayOrContainer, ArrayOrContainer],
763+
ArrayOrContainer
764+
],
765+
right: ArrayOrContainer
766+
) -> ArrayOrContainer:
767+
try:
768+
serialized = serialize_container(right)
769+
except NotAnArrayContainerError:
770+
return op(self.broadcastee, right)
771+
772+
return deserialize_container(right, [
773+
(k, op(self.broadcastee, right_v)
774+
if isinstance(right_v, self._stop_types) else
775+
self._binary_op(op, right_v)
776+
)
777+
for k, right_v in serialized])
778+
779+
def _rev_binary_op(self,
780+
op: Callable[
781+
[ArrayOrContainer, ArrayOrContainer],
782+
ArrayOrContainer
783+
],
784+
left: ArrayOrContainer
785+
) -> ArrayOrContainer:
786+
try:
787+
serialized = serialize_container(left)
788+
except NotAnArrayContainerError:
789+
return op(left, self.broadcastee)
790+
791+
return deserialize_container(left, [
792+
(k, op(left_v, self.broadcastee)
793+
if isinstance(left_v, self._stop_types) else
794+
self._rev_binary_op(op, left_v)
795+
)
796+
for k, left_v in serialized])
797+
798+
__add__ = partialmethod(_binary_op, operator.add)
799+
__radd__ = partialmethod(_rev_binary_op, operator.add)
800+
__sub__ = partialmethod(_binary_op, operator.sub)
801+
__rsub__ = partialmethod(_rev_binary_op, operator.sub)
802+
__mul__ = partialmethod(_binary_op, operator.mul)
803+
__rmul__ = partialmethod(_rev_binary_op, operator.mul)
804+
__truediv__ = partialmethod(_binary_op, operator.truediv)
805+
__rtruediv__ = partialmethod(_rev_binary_op, operator.truediv)
806+
__floordiv__ = partialmethod(_binary_op, operator.floordiv)
807+
__rfloordiv__ = partialmethod(_rev_binary_op, operator.floordiv)
808+
__mod__ = partialmethod(_binary_op, operator.mod)
809+
__rmod__ = partialmethod(_rev_binary_op, operator.mod)
810+
__pow__ = partialmethod(_binary_op, operator.pow)
811+
__rpow__ = partialmethod(_rev_binary_op, operator.pow)
812+
813+
__lshift__ = partialmethod(_binary_op, operator.lshift)
814+
__rlshift__ = partialmethod(_rev_binary_op, operator.lshift)
815+
__rshift__ = partialmethod(_binary_op, operator.rshift)
816+
__rrshift__ = partialmethod(_rev_binary_op, operator.rshift)
817+
__and__ = partialmethod(_binary_op, operator.and_)
818+
__rand__ = partialmethod(_rev_binary_op, operator.and_)
819+
__or__ = partialmethod(_binary_op, operator.or_)
820+
__ror__ = partialmethod(_rev_binary_op, operator.or_)
821+
822+
# }}}
823+
690824
# vim: foldmethod=marker

doc/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@
4242

4343
nitpick_ignore_regex = [
4444
["py:class", r"arraycontext\.context\.ContainerOrScalarT"],
45+
["py:class", r"ArrayOrContainer"],
4546
]

test/test_arraycontext.py

+85-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pytools.tag import Tag
3535

3636
from arraycontext import (
37+
BcastUntilActxArray,
3738
EagerJAXArrayContext,
3839
NumpyArrayContext,
3940
PyOpenCLArrayContext,
@@ -1189,7 +1190,7 @@ def test_container_equality(actx_factory):
11891190
# }}}
11901191

11911192

1192-
# {{{ test_no_leaf_array_type_broadcasting
1193+
# {{{ test_leaf_array_type_broadcasting
11931194

11941195
def test_no_leaf_array_type_broadcasting(actx_factory):
11951196
from testlib import Foo
@@ -1198,14 +1199,85 @@ def test_no_leaf_array_type_broadcasting(actx_factory):
11981199

11991200
dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))
12001201
foo = Foo(dof_ary)
1202+
bar = foo + 4
1203+
1204+
bcast = partial(BcastUntilActxArray, actx)
12011205

12021206
actx_ary = actx.from_numpy(4*np.ones((3, )))
12031207
with pytest.raises(TypeError):
12041208
foo + actx_ary
12051209

1210+
baz = foo + bcast(actx_ary)
1211+
qux = bcast(actx_ary) + foo
1212+
1213+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1214+
actx.to_numpy(baz.u[0]))
1215+
1216+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1217+
actx.to_numpy(qux.u[0]))
1218+
1219+
baz = foo + bcast(actx_ary)
1220+
qux = bcast(actx_ary) + foo
1221+
1222+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1223+
actx.to_numpy(baz.u[0]))
1224+
1225+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1226+
actx.to_numpy(qux.u[0]))
1227+
1228+
mc = MyContainer(
1229+
name="hi",
1230+
mass=dof_ary,
1231+
momentum=make_obj_array([dof_ary, dof_ary]),
1232+
enthalpy=dof_ary)
1233+
12061234
with pytest.raises(TypeError):
1207-
foo + actx.from_numpy(np.array(4))
1235+
mc_op = mc + actx_ary
1236+
1237+
mom_op = mc.momentum + bcast(actx_ary)
1238+
np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0]))
1239+
1240+
mc_op = mc + bcast(actx_ary)
1241+
np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0]))
1242+
np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0]))
1243+
1244+
mom_op = mc.momentum + bcast(actx_ary)
1245+
np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0]))
1246+
1247+
mc_op = mc + bcast(actx_ary)
1248+
np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0]))
1249+
np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0]))
1250+
1251+
def _actx_allows_scalar_broadcast(actx):
1252+
if not isinstance(actx, PyOpenCLArrayContext):
1253+
return True
1254+
else:
1255+
import pyopencl as cl
1256+
1257+
# See https://github.com/inducer/pyopencl/issues/498
1258+
return cl.version.VERSION > (2021, 2, 5)
1259+
1260+
if _actx_allows_scalar_broadcast(actx):
1261+
with pytest.raises(TypeError):
1262+
foo + actx.from_numpy(np.array(4))
1263+
1264+
quuz = bcast(actx.from_numpy(np.array(4))) + foo
1265+
quux = foo + bcast(actx.from_numpy(np.array(4)))
1266+
1267+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1268+
actx.to_numpy(quux.u[0]))
1269+
1270+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1271+
actx.to_numpy(quuz.u[0]))
12081272

1273+
quuz = bcast(actx.from_numpy(np.array(4))) + foo
1274+
quux = foo + bcast(actx.from_numpy(np.array(4)))
1275+
1276+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1277+
actx.to_numpy(quux.u[0]))
1278+
1279+
np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
1280+
actx.to_numpy(quuz.u[0]))
12091281
# }}}
12101282

12111283

@@ -1220,6 +1292,8 @@ def test_outer(actx_factory):
12201292
b_ary_of_dofs = a_ary_of_dofs + 1
12211293
b_bcast_dc_of_dofs = a_bcast_dc_of_dofs + 1
12221294

1295+
bcast = partial(BcastUntilActxArray, actx)
1296+
12231297
from arraycontext import outer
12241298

12251299
def equal(a, b):
@@ -1274,6 +1348,15 @@ def equal(a, b):
12741348
b_bcast_dc_of_dofs.momentum),
12751349
enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy))
12761350

1351+
# Array context scalars
1352+
two = actx.from_numpy(np.array(2))
1353+
assert equal(
1354+
outer(bcast(two), b_bcast_dc_of_dofs),
1355+
bcast(two)*b_bcast_dc_of_dofs)
1356+
assert equal(
1357+
outer(a_bcast_dc_of_dofs, bcast(two)),
1358+
a_bcast_dc_of_dofs*bcast(two))
1359+
12771360
# }}}
12781361

12791362

0 commit comments

Comments
 (0)