diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 74adae96..a7eb67fd 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -50,6 +50,10 @@ ) from .container.dataclass import dataclass_array_container from .container.traversal import ( + bcast_left, + bcast_left_until_actx_array, + bcast_right, + bcast_right_until_actx_array, flat_size_and_dtype, flatten, freeze, @@ -129,6 +133,10 @@ "ScalarLike", "SerializationKey", "SerializedContainer", + "bcast_left", + "bcast_left_until_actx_array", + "bcast_right", + "bcast_right_until_actx_array", "dataclass_array_container", "deserialize_container", "flat_size_and_dtype", diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 73dee6d9..6966046e 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -413,10 +413,8 @@ def wrap(cls: Any) -> Any: warn( f"Broadcasting array context array types across {cls} " "has been implicitly " - "enabled. As of 2025, this will no longer work. " - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/pull/190. " + "enabled. As of 2026, this will no longer work. " + "Use bcast_left/right_until_actx_array instead." "To opt out now (and avoid this warning), " "pass _bcast_actx_array_type=False.", DeprecationWarning, stacklevel=2) @@ -603,11 +601,9 @@ def {fname}(arg1): if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): warn("Broadcasting {cls} over array " f"context array type {{type(arg2)}} is deprecated " - "and will no longer work in 2025. " - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/" - "pull/190. ", + "and will no longer work in 2026. " + "Use bcast_left/right_until_actx_array " + "instead.", DeprecationWarning, stacklevel=2) return cls({bcast_init_args_arg1_is_outer}) @@ -654,11 +650,9 @@ def {fname}(arg2, arg1): warn("Broadcasting {cls} over array " f"context array type {{type(arg1)}} " "is deprecated " - "and will no longer work in 2025." - "There is no replacement as of right now. " - "See the discussion in " - "https://github.com/inducer/arraycontext/" - "pull/190. ", + "and will no longer work in 2026." + "Use bcast_left/right_until_actx_array " + "instead.", DeprecationWarning, stacklevel=2) return cls({bcast_init_args_arg2_is_outer}) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index ef7141f5..326e7cab 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -78,6 +78,7 @@ import numpy as np from arraycontext.container import ( + ArithArrayContainer, ArrayContainer, NotAnArrayContainerError, SerializationKey, @@ -88,6 +89,7 @@ from arraycontext.context import ( Array, ArrayContext, + ArrayOrArithContainer, ArrayOrContainer, ArrayOrContainerOrScalar, ArrayOrContainerT, @@ -987,4 +989,79 @@ def treat_as_scalar(x: Any) -> bool: # }}} + +# {{{ + +def bcast_left( + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], + ArrayOrArithContainer], + left: ArrayOrArithContainer, + right: ArithArrayContainer, + ) -> ArrayOrArithContainer: + try: + serialized = serialize_container(right) + except NotAnArrayContainerError: + return op(left, right) + + return deserialize_container(right, [ + (k, op(left, right_v)) for k, right_v in serialized]) + + +def bcast_right( + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], + ArrayOrArithContainer], + left: ArrayOrArithContainer, + right: ArithArrayContainer, + ) -> ArrayOrArithContainer: + try: + serialized = serialize_container(left) + except NotAnArrayContainerError: + return op(left, right) + + return deserialize_container(right, [ + (k, op(left_v, right)) for k, left_v in serialized]) + + +def bcast_left_until_actx_array( + actx: ArrayContext, + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], + ArrayOrArithContainer], + left: ArrayOrArithContainer, + right: ArithArrayContainer, + ) -> ArrayOrArithContainer: + try: + serialized = serialize_container(right) + except NotAnArrayContainerError: + return op(left, right) + + return deserialize_container(right, [ + (k, op(left, right_v) + if isinstance(right_v, actx.array_types) else + bcast_left_until_actx_array(actx, op, left, right_v) + ) + for k, right_v in serialized]) + + +def bcast_right_until_actx_array( + actx: ArrayContext, + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], + ArrayOrArithContainer], + left: ArrayOrArithContainer, + right: ArithArrayContainer, + ) -> ArrayOrArithContainer: + try: + serialized = serialize_container(left) + except NotAnArrayContainerError: + return op(left, right) + + return deserialize_container(right, [ + (k, op(left_v, right) + if isinstance(left_v, actx.array_types) else + bcast_right_until_actx_array(actx, op, left_v, right) + ) + for k, left_v in serialized]) + +# }}} + + # vim: foldmethod=marker