diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 74adae96..4abe4a54 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -46,6 +46,7 @@ serialize_container, ) from .container.arithmetic import ( + BcastUntilActxArray, with_container_arithmetic, ) from .container.dataclass import dataclass_array_container @@ -115,6 +116,7 @@ "ArrayOrContainerOrScalarT", "ArrayOrContainerT", "ArrayT", + "BcastUntilActxArray", "CommonSubexpressionTag", "EagerJAXArrayContext", "ElementwiseMapKernelTag", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index e70b51df..1ebc2cb2 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -168,6 +168,8 @@ def __mul__(self, other: ArrayOrScalar | Self) -> Self: ... def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ... def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ... def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ... + def __pow__(self, other: ArrayOrScalar | Self) -> Self: ... + def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 73dee6d9..06ffe738 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -6,6 +6,8 @@ .. currentmodule:: arraycontext .. autofunction:: with_container_arithmetic + +.. autoclass:: BcastUntilActxArray """ @@ -34,12 +36,23 @@ """ import enum +import operator from collections.abc import Callable +from dataclasses import dataclass, field +from functools import partialmethod +from numbers import Number from typing import Any, TypeVar from warnings import warn import numpy as np +from arraycontext.container import ( + NotAnArrayContainerError, + deserialize_container, + serialize_container, +) +from arraycontext.context import ArrayContext, ArrayOrContainer + # {{{ with_container_arithmetic @@ -142,8 +155,9 @@ def __instancecheck__(cls, instance: Any) -> bool: warn( "Broadcasting container against non-object numpy array. " "This was never documented to work and will now stop working in " - "2025. Convert the array to an object array to preserve the " - "current semantics.", DeprecationWarning, stacklevel=3) + "2025. Convert the array to an object array or use " + "arraycontext.BcastUntilActxArray (or similar) to obtain the desired " + "broadcasting semantics.", DeprecationWarning, stacklevel=3) return True else: return False @@ -207,6 +221,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__` Each operator class also includes the "reverse" operators if applicable. + .. note:: + + For the generated binary arithmetic operators, if certain types + should be broadcast over the container (with the container as the + 'outer' structure) but are not handled in this way by their types, + you may wrap them in :class:`BcastUntilActxArray` to achieve + the desired semantics. + .. note:: To generate the code implementing the operators, this function relies on @@ -402,8 +424,9 @@ def wrap(cls: Any) -> Any: warn( f"Broadcasting array context array types across {cls} " "has been explicitly " - "enabled. As of 2025, this will stop working. " - "There is no replacement as of right now. " + "enabled. As of 2026, this will stop working. " + "Use arraycontext.Bcast* object wrappers for " + "roughly equivalent functionality. " "See the discussion in " "https://github.com/inducer/arraycontext/pull/190. " "To opt out now (and avoid this warning), " @@ -413,8 +436,9 @@ def wrap(cls: Any) -> Any: warn( f"Broadcasting array context array types across {cls} " "has been implicitly " - "enabled. As of 2025, this will no longer work. " - "There is no replacement as of right now. " + "enabled. As of 2026, this will no longer work. " + "Use arraycontext.Bcast* object wrappers for " + "roughly equivalent functionality. " "See the discussion in " "https://github.com/inducer/arraycontext/pull/190. " "To opt out now (and avoid this warning), " @@ -603,8 +627,9 @@ def {fname}(arg1): if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): warn("Broadcasting {cls} over array " f"context array type {{type(arg2)}} is deprecated " - "and will no longer work in 2025. " - "There is no replacement as of right now. " + "and will no longer work in 2026. " + "Use arraycontext.Bcast* object wrappers for " + "roughly equivalent functionality. " "See the discussion in " "https://github.com/inducer/arraycontext/" "pull/190. ", @@ -654,8 +679,10 @@ def {fname}(arg2, arg1): warn("Broadcasting {cls} over array " f"context array type {{type(arg1)}} " "is deprecated " - "and will no longer work in 2025." - "There is no replacement as of right now. " + "and will no longer work in 2026." + "Use arraycontext.Bcast* object " + "wrappers for roughly equivalent " + "functionality. " "See the discussion in " "https://github.com/inducer/arraycontext/" "pull/190. ", @@ -687,4 +714,111 @@ def {fname}(arg2, arg1): # }}} +# {{{ Bcast object-ified broadcast rules + +# Possible advantages of the "Bcast" broadcast-rule-as-object design: +# +# - If one rule does not fit the user's need, they can straightforwardly use +# another. +# +# - It's straightforward to find where certain broadcast rules are used. +# +# - The broadcast rule can contain more state. For example, it's now easy +# for the rule to know what array context should be used to determine +# actx array types. +# +# Possible downsides of the "Bcast" broadcast-rule-as-object design: +# +# - User code is a bit more wordy. + +@dataclass(frozen=True) +class BcastUntilActxArray: + """ + An operator-overloading wrapper around an object (*broadcastee*) that should be + broadcast across array containers until the 'opposite' operand is one of the + :attr:`~arraycontext.ArrayContext.array_types` + of *actx* or a :class:`~numbers.Number`. + + Suggested usage pattern:: + + bcast = functools.partial(BcastUntilActxArray, actx) + + container + bcast(actx_array) + + .. automethod:: __init__ + """ + + array_context: ArrayContext + broadcastee: ArrayOrContainer + + _stop_types: tuple[type, ...] = field(init=False) + + def __post_init__(self) -> None: + object.__setattr__( + self, "_stop_types", (*self.array_context.array_types, Number)) + + def _binary_op(self, + op: Callable[ + [ArrayOrContainer, ArrayOrContainer], + ArrayOrContainer + ], + right: ArrayOrContainer + ) -> ArrayOrContainer: + try: + serialized = serialize_container(right) + except NotAnArrayContainerError: + return op(self.broadcastee, right) + + return deserialize_container(right, [ + (k, op(self.broadcastee, right_v) + if isinstance(right_v, self._stop_types) else + self._binary_op(op, right_v) + ) + for k, right_v in serialized]) + + def _rev_binary_op(self, + op: Callable[ + [ArrayOrContainer, ArrayOrContainer], + ArrayOrContainer + ], + left: ArrayOrContainer + ) -> ArrayOrContainer: + try: + serialized = serialize_container(left) + except NotAnArrayContainerError: + return op(left, self.broadcastee) + + return deserialize_container(left, [ + (k, op(left_v, self.broadcastee) + if isinstance(left_v, self._stop_types) else + self._rev_binary_op(op, left_v) + ) + for k, left_v in serialized]) + + __add__ = partialmethod(_binary_op, operator.add) + __radd__ = partialmethod(_rev_binary_op, operator.add) + __sub__ = partialmethod(_binary_op, operator.sub) + __rsub__ = partialmethod(_rev_binary_op, operator.sub) + __mul__ = partialmethod(_binary_op, operator.mul) + __rmul__ = partialmethod(_rev_binary_op, operator.mul) + __truediv__ = partialmethod(_binary_op, operator.truediv) + __rtruediv__ = partialmethod(_rev_binary_op, operator.truediv) + __floordiv__ = partialmethod(_binary_op, operator.floordiv) + __rfloordiv__ = partialmethod(_rev_binary_op, operator.floordiv) + __mod__ = partialmethod(_binary_op, operator.mod) + __rmod__ = partialmethod(_rev_binary_op, operator.mod) + __pow__ = partialmethod(_binary_op, operator.pow) + __rpow__ = partialmethod(_rev_binary_op, operator.pow) + + __lshift__ = partialmethod(_binary_op, operator.lshift) + __rlshift__ = partialmethod(_rev_binary_op, operator.lshift) + __rshift__ = partialmethod(_binary_op, operator.rshift) + __rrshift__ = partialmethod(_rev_binary_op, operator.rshift) + __and__ = partialmethod(_binary_op, operator.and_) + __rand__ = partialmethod(_rev_binary_op, operator.and_) + __or__ = partialmethod(_binary_op, operator.or_) + __ror__ = partialmethod(_rev_binary_op, operator.or_) + +# }}} + # vim: foldmethod=marker diff --git a/doc/conf.py b/doc/conf.py index 0042ae57..f01e4072 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -42,4 +42,5 @@ nitpick_ignore_regex = [ ["py:class", r"arraycontext\.context\.ContainerOrScalarT"], + ["py:class", r"ArrayOrContainer"], ] diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 904b8ad9..31fa9e79 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -34,6 +34,7 @@ from pytools.tag import Tag from arraycontext import ( + BcastUntilActxArray, EagerJAXArrayContext, NumpyArrayContext, PyOpenCLArrayContext, @@ -1189,7 +1190,7 @@ def test_container_equality(actx_factory): # }}} -# {{{ test_no_leaf_array_type_broadcasting +# {{{ test_leaf_array_type_broadcasting def test_no_leaf_array_type_broadcasting(actx_factory): from testlib import Foo @@ -1198,14 +1199,85 @@ def test_no_leaf_array_type_broadcasting(actx_factory): dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )) foo = Foo(dof_ary) + bar = foo + 4 + + bcast = partial(BcastUntilActxArray, actx) actx_ary = actx.from_numpy(4*np.ones((3, ))) with pytest.raises(TypeError): foo + actx_ary + baz = foo + bcast(actx_ary) + qux = bcast(actx_ary) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + baz = foo + bcast(actx_ary) + qux = bcast(actx_ary) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + mc = MyContainer( + name="hi", + mass=dof_ary, + momentum=make_obj_array([dof_ary, dof_ary]), + enthalpy=dof_ary) + with pytest.raises(TypeError): - foo + actx.from_numpy(np.array(4)) + mc_op = mc + actx_ary + + mom_op = mc.momentum + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0])) + + mc_op = mc + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + + mom_op = mc.momentum + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0])) + + mc_op = mc + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + + def _actx_allows_scalar_broadcast(actx): + if not isinstance(actx, PyOpenCLArrayContext): + return True + else: + import pyopencl as cl + + # See https://github.com/inducer/pyopencl/issues/498 + return cl.version.VERSION > (2021, 2, 5) + + if _actx_allows_scalar_broadcast(actx): + with pytest.raises(TypeError): + foo + actx.from_numpy(np.array(4)) + + quuz = bcast(actx.from_numpy(np.array(4))) + foo + quux = foo + bcast(actx.from_numpy(np.array(4))) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) + quuz = bcast(actx.from_numpy(np.array(4))) + foo + quux = foo + bcast(actx.from_numpy(np.array(4))) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) # }}} @@ -1220,6 +1292,8 @@ def test_outer(actx_factory): b_ary_of_dofs = a_ary_of_dofs + 1 b_bcast_dc_of_dofs = a_bcast_dc_of_dofs + 1 + bcast = partial(BcastUntilActxArray, actx) + from arraycontext import outer def equal(a, b): @@ -1274,6 +1348,15 @@ def equal(a, b): b_bcast_dc_of_dofs.momentum), enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy)) + # Array context scalars + two = actx.from_numpy(np.array(2)) + assert equal( + outer(bcast(two), b_bcast_dc_of_dofs), + bcast(two)*b_bcast_dc_of_dofs) + assert equal( + outer(a_bcast_dc_of_dofs, bcast(two)), + a_bcast_dc_of_dofs*bcast(two)) + # }}}