Skip to content

Commit 720e1ec

Browse files
inducermajosm
andcommitted
Add BcastUntilActxArray
Co-authored-by: Matt Smith <[email protected]>
1 parent 602a63f commit 720e1ec

File tree

2 files changed

+94
-8
lines changed

2 files changed

+94
-8
lines changed

arraycontext/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
serialize_container,
4747
)
4848
from .container.arithmetic import (
49+
BcastUntilActxArray,
4950
with_container_arithmetic,
5051
)
5152
from .container.dataclass import dataclass_array_container
@@ -115,6 +116,7 @@
115116
"ArrayOrContainerOrScalarT",
116117
"ArrayOrContainerT",
117118
"ArrayT",
119+
"BcastUntilActxArray",
118120
"CommonSubexpressionTag",
119121
"EagerJAXArrayContext",
120122
"ElementwiseMapKernelTag",

arraycontext/container/arithmetic.py

+92-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
.. currentmodule:: arraycontext
77
88
.. autofunction:: with_container_arithmetic
9+
10+
.. autoclass:: BcastUntilActxArray
911
"""
1012

1113

@@ -34,12 +36,22 @@
3436
"""
3537

3638
import enum
39+
import operator
3740
from collections.abc import Callable
41+
from dataclasses import dataclass
42+
from functools import partialmethod
3843
from typing import Any, TypeVar
3944
from warnings import warn
4045

4146
import numpy as np
4247

48+
from arraycontext.container import (
49+
NotAnArrayContainerError,
50+
deserialize_container,
51+
serialize_container,
52+
)
53+
from arraycontext.context import ArrayContext, ArrayOrArithContainer
54+
4355

4456
# {{{ with_container_arithmetic
4557

@@ -402,8 +414,9 @@ def wrap(cls: Any) -> Any:
402414
warn(
403415
f"Broadcasting array context array types across {cls} "
404416
"has been explicitly "
405-
"enabled. As of 2025, this will stop working. "
406-
"There is no replacement as of right now. "
417+
"enabled. As of 2026, this will stop working. "
418+
"Use the arraycontext.Bcast* object wrappers for "
419+
"roughly equivalent functionality. "
407420
"See the discussion in "
408421
"https://github.com/inducer/arraycontext/pull/190. "
409422
"To opt out now (and avoid this warning), "
@@ -413,8 +426,9 @@ def wrap(cls: Any) -> Any:
413426
warn(
414427
f"Broadcasting array context array types across {cls} "
415428
"has been implicitly "
416-
"enabled. As of 2025, this will no longer work. "
417-
"There is no replacement as of right now. "
429+
"enabled. As of 2026, this will no longer work. "
430+
"Use the arraycontext.Bcast* object wrappers for "
431+
"roughly equivalent functionality. "
418432
"See the discussion in "
419433
"https://github.com/inducer/arraycontext/pull/190. "
420434
"To opt out now (and avoid this warning), "
@@ -603,8 +617,9 @@ def {fname}(arg1):
603617
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
604618
warn("Broadcasting {cls} over array "
605619
f"context array type {{type(arg2)}} is deprecated "
606-
"and will no longer work in 2025. "
607-
"There is no replacement as of right now. "
620+
"and will no longer work in 2026. "
621+
"Use the arraycontext.Bcast* object wrappers for "
622+
"roughly equivalent functionality. "
608623
"See the discussion in "
609624
"https://github.com/inducer/arraycontext/"
610625
"pull/190. ",
@@ -654,8 +669,10 @@ def {fname}(arg2, arg1):
654669
warn("Broadcasting {cls} over array "
655670
f"context array type {{type(arg1)}} "
656671
"is deprecated "
657-
"and will no longer work in 2025."
658-
"There is no replacement as of right now. "
672+
"and will no longer work in 2026."
673+
"Use the arraycontext.Bcast* object "
674+
"rappers for roughly equivalent "
675+
"functionality. "
659676
"See the discussion in "
660677
"https://github.com/inducer/arraycontext/"
661678
"pull/190. ",
@@ -687,4 +704,71 @@ def {fname}(arg2, arg1):
687704
# }}}
688705

689706

707+
@dataclass(frozen=True)
708+
class BcastUntilActxArray:
709+
array_context: ArrayContext
710+
broadcastee: ArrayOrArithContainer
711+
712+
def _binary_op(self,
713+
op: Callable[
714+
[ArrayOrArithContainer, ArrayOrArithContainer],
715+
ArrayOrArithContainer
716+
],
717+
right: ArrayOrArithContainer
718+
) -> ArrayOrArithContainer:
719+
try:
720+
serialized = serialize_container(right)
721+
except NotAnArrayContainerError:
722+
return op(self.broadcastee, right)
723+
724+
return deserialize_container(right, [
725+
(k, op(self.broadcastee, right_v)
726+
if isinstance(right_v, self.array_context.array_types) else
727+
self._binary_op(op, right_v)
728+
)
729+
for k, right_v in serialized])
730+
731+
def _rev_binary_op(self,
732+
op: Callable[
733+
[ArrayOrArithContainer, ArrayOrArithContainer],
734+
ArrayOrArithContainer
735+
],
736+
left: ArrayOrArithContainer
737+
) -> ArrayOrArithContainer:
738+
try:
739+
serialized = serialize_container(left)
740+
except NotAnArrayContainerError:
741+
return op(left, self.broadcastee)
742+
743+
return deserialize_container(left, [
744+
(k, op(left_v, self.broadcastee)
745+
if isinstance(left_v, self.array_context.array_types) else
746+
self._rev_binary_op(op, left_v)
747+
)
748+
for k, left_v in serialized])
749+
750+
__add__ = partialmethod(_binary_op, operator.add)
751+
__radd__ = partialmethod(_rev_binary_op, operator.add)
752+
__sub__ = partialmethod(_binary_op, operator.sub)
753+
__rsub__ = partialmethod(_rev_binary_op, operator.sub)
754+
__mul__ = partialmethod(_binary_op, operator.mul)
755+
__rmul__ = partialmethod(_rev_binary_op, operator.mul)
756+
__truediv__ = partialmethod(_binary_op, operator.truediv)
757+
__rtruediv__ = partialmethod(_rev_binary_op, operator.truediv)
758+
__floordiv__ = partialmethod(_binary_op, operator.floordiv)
759+
__rfloordiv__ = partialmethod(_rev_binary_op, operator.floordiv)
760+
__mod__ = partialmethod(_binary_op, operator.mod)
761+
__rmod__ = partialmethod(_rev_binary_op, operator.mod)
762+
__pow__ = partialmethod(_binary_op, operator.pow)
763+
__rpow__ = partialmethod(_rev_binary_op, operator.pow)
764+
765+
__lshift__ = partialmethod(_binary_op, operator.lshift)
766+
__rlshift__ = partialmethod(_rev_binary_op, operator.lshift)
767+
__rshift__ = partialmethod(_binary_op, operator.rshift)
768+
__rrshift__ = partialmethod(_rev_binary_op, operator.rshift)
769+
__and__ = partialmethod(_binary_op, operator.and_)
770+
__rand__ = partialmethod(_rev_binary_op, operator.and_)
771+
__or__ = partialmethod(_binary_op, operator.or_)
772+
__ror__ = partialmethod(_rev_binary_op, operator.or_)
773+
690774
# vim: foldmethod=marker

0 commit comments

Comments
 (0)