Skip to content

Bcast object wrappers, attempt 2 #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
serialize_container,
)
from .container.arithmetic import (
BcastUntilActxArray,
with_container_arithmetic,
)
from .container.dataclass import dataclass_array_container
Expand Down Expand Up @@ -115,6 +116,7 @@
"ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"ArrayT",
"BcastUntilActxArray",
"CommonSubexpressionTag",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
Expand Down
2 changes: 2 additions & 0 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...


ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)
Expand Down
154 changes: 144 additions & 10 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
.. currentmodule:: arraycontext

.. autofunction:: with_container_arithmetic

.. autoclass:: BcastUntilActxArray
"""


Expand Down Expand Up @@ -34,12 +36,23 @@
"""

import enum
import operator
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import partialmethod
from numbers import Number
from typing import Any, TypeVar
from warnings import warn

import numpy as np

from arraycontext.container import (
NotAnArrayContainerError,
deserialize_container,
serialize_container,
)
from arraycontext.context import ArrayContext, ArrayOrContainer


# {{{ with_container_arithmetic

Expand Down Expand Up @@ -142,8 +155,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
warn(
"Broadcasting container against non-object numpy array. "
"This was never documented to work and will now stop working in "
"2025. Convert the array to an object array to preserve the "
"current semantics.", DeprecationWarning, stacklevel=3)
"2025. Convert the array to an object array or use "
"arraycontext.BcastUntilActxArray (or similar) to obtain the desired "
"broadcasting semantics.", DeprecationWarning, stacklevel=3)
return True
else:
return False
Expand Down Expand Up @@ -207,6 +221,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`

Each operator class also includes the "reverse" operators if applicable.

.. note::

For the generated binary arithmetic operators, if certain types
should be broadcast over the container (with the container as the
'outer' structure) but are not handled in this way by their types,
you may wrap them in :class:`BcastUntilActxArray` to achieve
the desired semantics.

.. note::

To generate the code implementing the operators, this function relies on
Expand Down Expand Up @@ -402,8 +424,9 @@ def wrap(cls: Any) -> Any:
warn(
f"Broadcasting array context array types across {cls} "
"has been explicitly "
"enabled. As of 2025, this will stop working. "
"There is no replacement as of right now. "
"enabled. As of 2026, this will stop working. "
"Use arraycontext.Bcast* object wrappers for "
"roughly equivalent functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"To opt out now (and avoid this warning), "
Expand All @@ -413,8 +436,9 @@ def wrap(cls: Any) -> Any:
warn(
f"Broadcasting array context array types across {cls} "
"has been implicitly "
"enabled. As of 2025, this will no longer work. "
"There is no replacement as of right now. "
"enabled. As of 2026, this will no longer work. "
"Use arraycontext.Bcast* object wrappers for "
"roughly equivalent functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"To opt out now (and avoid this warning), "
Expand Down Expand Up @@ -603,8 +627,9 @@ def {fname}(arg1):
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg2)}} is deprecated "
"and will no longer work in 2025. "
"There is no replacement as of right now. "
"and will no longer work in 2026. "
"Use arraycontext.Bcast* object wrappers for "
"roughly equivalent functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
Expand Down Expand Up @@ -654,8 +679,10 @@ def {fname}(arg2, arg1):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg1)}} "
"is deprecated "
"and will no longer work in 2025."
"There is no replacement as of right now. "
"and will no longer work in 2026."
"Use arraycontext.Bcast* object "
"wrappers for roughly equivalent "
"functionality. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
Expand Down Expand Up @@ -687,4 +714,111 @@ def {fname}(arg2, arg1):
# }}}


# {{{ Bcast object-ified broadcast rules

# Possible advantages of the "Bcast" broadcast-rule-as-object design:
#
# - If one rule does not fit the user's need, they can straightforwardly use
# another.
#
# - It's straightforward to find where certain broadcast rules are used.
#
# - The broadcast rule can contain more state. For example, it's now easy
# for the rule to know what array context should be used to determine
# actx array types.
#
# Possible downsides of the "Bcast" broadcast-rule-as-object design:
#
# - User code is a bit more wordy.

@dataclass(frozen=True)
class BcastUntilActxArray:
"""
An operator-overloading wrapper around an object (*broadcastee*) that should be
broadcast across array containers until the 'opposite' operand is one of the
:attr:`~arraycontext.ArrayContext.array_types`
of *actx* or a :class:`~numbers.Number`.

Suggested usage pattern::

bcast = functools.partial(BcastUntilActxArray, actx)

container + bcast(actx_array)

.. automethod:: __init__
"""

array_context: ArrayContext
broadcastee: ArrayOrContainer

_stop_types: tuple[type, ...] = field(init=False)

def __post_init__(self) -> None:
object.__setattr__(
self, "_stop_types", (*self.array_context.array_types, Number))

def _binary_op(self,
op: Callable[
[ArrayOrContainer, ArrayOrContainer],
ArrayOrContainer
],
right: ArrayOrContainer
) -> ArrayOrContainer:
try:
serialized = serialize_container(right)
except NotAnArrayContainerError:
return op(self.broadcastee, right)

return deserialize_container(right, [
(k, op(self.broadcastee, right_v)
if isinstance(right_v, self._stop_types) else
self._binary_op(op, right_v)
)
for k, right_v in serialized])

def _rev_binary_op(self,
op: Callable[
[ArrayOrContainer, ArrayOrContainer],
ArrayOrContainer
],
left: ArrayOrContainer
) -> ArrayOrContainer:
try:
serialized = serialize_container(left)
except NotAnArrayContainerError:
return op(left, self.broadcastee)

return deserialize_container(left, [
(k, op(left_v, self.broadcastee)
if isinstance(left_v, self._stop_types) else
self._rev_binary_op(op, left_v)
)
for k, left_v in serialized])

__add__ = partialmethod(_binary_op, operator.add)
__radd__ = partialmethod(_rev_binary_op, operator.add)
__sub__ = partialmethod(_binary_op, operator.sub)
__rsub__ = partialmethod(_rev_binary_op, operator.sub)
__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_rev_binary_op, operator.mul)
__truediv__ = partialmethod(_binary_op, operator.truediv)
__rtruediv__ = partialmethod(_rev_binary_op, operator.truediv)
__floordiv__ = partialmethod(_binary_op, operator.floordiv)
__rfloordiv__ = partialmethod(_rev_binary_op, operator.floordiv)
__mod__ = partialmethod(_binary_op, operator.mod)
__rmod__ = partialmethod(_rev_binary_op, operator.mod)
__pow__ = partialmethod(_binary_op, operator.pow)
__rpow__ = partialmethod(_rev_binary_op, operator.pow)

__lshift__ = partialmethod(_binary_op, operator.lshift)
__rlshift__ = partialmethod(_rev_binary_op, operator.lshift)
__rshift__ = partialmethod(_binary_op, operator.rshift)
__rrshift__ = partialmethod(_rev_binary_op, operator.rshift)
__and__ = partialmethod(_binary_op, operator.and_)
__rand__ = partialmethod(_rev_binary_op, operator.and_)
__or__ = partialmethod(_binary_op, operator.or_)
__ror__ = partialmethod(_rev_binary_op, operator.or_)

# }}}

# vim: foldmethod=marker
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@

nitpick_ignore_regex = [
["py:class", r"arraycontext\.context\.ContainerOrScalarT"],
["py:class", r"ArrayOrContainer"],
]
87 changes: 85 additions & 2 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytools.tag import Tag

from arraycontext import (
BcastUntilActxArray,
EagerJAXArrayContext,
NumpyArrayContext,
PyOpenCLArrayContext,
Expand Down Expand Up @@ -1189,7 +1190,7 @@ def test_container_equality(actx_factory):
# }}}


# {{{ test_no_leaf_array_type_broadcasting
# {{{ test_leaf_array_type_broadcasting

def test_no_leaf_array_type_broadcasting(actx_factory):
from testlib import Foo
Expand All @@ -1198,14 +1199,85 @@ def test_no_leaf_array_type_broadcasting(actx_factory):

dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))
foo = Foo(dof_ary)
bar = foo + 4

bcast = partial(BcastUntilActxArray, actx)

actx_ary = actx.from_numpy(4*np.ones((3, )))
with pytest.raises(TypeError):
foo + actx_ary

baz = foo + bcast(actx_ary)
qux = bcast(actx_ary) + foo

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(baz.u[0]))

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(qux.u[0]))

baz = foo + bcast(actx_ary)
qux = bcast(actx_ary) + foo

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(baz.u[0]))

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(qux.u[0]))

mc = MyContainer(
name="hi",
mass=dof_ary,
momentum=make_obj_array([dof_ary, dof_ary]),
enthalpy=dof_ary)

with pytest.raises(TypeError):
foo + actx.from_numpy(np.array(4))
mc_op = mc + actx_ary

mom_op = mc.momentum + bcast(actx_ary)
np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0]))

mc_op = mc + bcast(actx_ary)
np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0]))
np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0]))

mom_op = mc.momentum + bcast(actx_ary)
np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0]))

mc_op = mc + bcast(actx_ary)
np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0]))
np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0]))

def _actx_allows_scalar_broadcast(actx):
if not isinstance(actx, PyOpenCLArrayContext):
return True
else:
import pyopencl as cl

# See https://github.com/inducer/pyopencl/issues/498
return cl.version.VERSION > (2021, 2, 5)

if _actx_allows_scalar_broadcast(actx):
with pytest.raises(TypeError):
foo + actx.from_numpy(np.array(4))

quuz = bcast(actx.from_numpy(np.array(4))) + foo
quux = foo + bcast(actx.from_numpy(np.array(4)))

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(quux.u[0]))

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(quuz.u[0]))

quuz = bcast(actx.from_numpy(np.array(4))) + foo
quux = foo + bcast(actx.from_numpy(np.array(4)))

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(quux.u[0]))

np.testing.assert_allclose(actx.to_numpy(bar.u[0]),
actx.to_numpy(quuz.u[0]))
# }}}


Expand All @@ -1220,6 +1292,8 @@ def test_outer(actx_factory):
b_ary_of_dofs = a_ary_of_dofs + 1
b_bcast_dc_of_dofs = a_bcast_dc_of_dofs + 1

bcast = partial(BcastUntilActxArray, actx)

from arraycontext import outer

def equal(a, b):
Expand Down Expand Up @@ -1274,6 +1348,15 @@ def equal(a, b):
b_bcast_dc_of_dofs.momentum),
enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy))

# Array context scalars
two = actx.from_numpy(np.array(2))
assert equal(
outer(bcast(two), b_bcast_dc_of_dofs),
bcast(two)*b_bcast_dc_of_dofs)
assert equal(
outer(a_bcast_dc_of_dofs, bcast(two)),
a_bcast_dc_of_dofs*bcast(two))

# }}}


Expand Down
Loading