Skip to content

Commit 50ec2f1

Browse files
committed
Add bcast and bcast_until_actx_array functions
1 parent 4aeaed4 commit 50ec2f1

File tree

3 files changed

+93
-14
lines changed

3 files changed

+93
-14
lines changed

arraycontext/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@
5050
)
5151
from .container.dataclass import dataclass_array_container
5252
from .container.traversal import (
53+
bcast_left,
54+
bcast_left_until_actx_array,
55+
bcast_right,
56+
bcast_right_until_actx_array,
5357
flat_size_and_dtype,
5458
flatten,
5559
freeze,
@@ -129,6 +133,10 @@
129133
"ScalarLike",
130134
"SerializationKey",
131135
"SerializedContainer",
136+
"bcast_left",
137+
"bcast_left_until_actx_array",
138+
"bcast_right",
139+
"bcast_right_until_actx_array",
132140
"dataclass_array_container",
133141
"deserialize_container",
134142
"flat_size_and_dtype",

arraycontext/container/arithmetic.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,8 @@ def wrap(cls: Any) -> Any:
413413
warn(
414414
f"Broadcasting array context array types across {cls} "
415415
"has been implicitly "
416-
"enabled. As of 2025, this will no longer work. "
417-
"There is no replacement as of right now. "
418-
"See the discussion in "
419-
"https://github.com/inducer/arraycontext/pull/190. "
416+
"enabled. As of 2026, this will no longer work. "
417+
"Use bcast_left/right_until_actx_array instead."
420418
"To opt out now (and avoid this warning), "
421419
"pass _bcast_actx_array_type=False.",
422420
DeprecationWarning, stacklevel=2)
@@ -603,11 +601,9 @@ def {fname}(arg1):
603601
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
604602
warn("Broadcasting {cls} over array "
605603
f"context array type {{type(arg2)}} is deprecated "
606-
"and will no longer work in 2025. "
607-
"There is no replacement as of right now. "
608-
"See the discussion in "
609-
"https://github.com/inducer/arraycontext/"
610-
"pull/190. ",
604+
"and will no longer work in 2026. "
605+
"Use bcast_left/right_until_actx_array "
606+
"instead.",
611607
DeprecationWarning, stacklevel=2)
612608
613609
return cls({bcast_init_args_arg1_is_outer})
@@ -654,11 +650,9 @@ def {fname}(arg2, arg1):
654650
warn("Broadcasting {cls} over array "
655651
f"context array type {{type(arg1)}} "
656652
"is deprecated "
657-
"and will no longer work in 2025."
658-
"There is no replacement as of right now. "
659-
"See the discussion in "
660-
"https://github.com/inducer/arraycontext/"
661-
"pull/190. ",
653+
"and will no longer work in 2026."
654+
"Use bcast_left/right_until_actx_array "
655+
"instead.",
662656
DeprecationWarning, stacklevel=2)
663657
664658
return cls({bcast_init_args_arg2_is_outer})

arraycontext/container/traversal.py

+77
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import numpy as np
7979

8080
from arraycontext.container import (
81+
ArithArrayContainer,
8182
ArrayContainer,
8283
NotAnArrayContainerError,
8384
SerializationKey,
@@ -88,6 +89,7 @@
8889
from arraycontext.context import (
8990
Array,
9091
ArrayContext,
92+
ArrayOrArithContainer,
9193
ArrayOrContainer,
9294
ArrayOrContainerOrScalar,
9395
ArrayOrContainerT,
@@ -987,4 +989,79 @@ def treat_as_scalar(x: Any) -> bool:
987989

988990
# }}}
989991

992+
993+
# {{{
994+
995+
def bcast_left(
996+
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
997+
ArrayOrArithContainer],
998+
left: ArrayOrArithContainer,
999+
right: ArithArrayContainer,
1000+
) -> ArrayOrArithContainer:
1001+
try:
1002+
serialized = serialize_container(right)
1003+
except NotAnArrayContainerError:
1004+
return op(left, right)
1005+
1006+
return deserialize_container(right, [
1007+
(k, op(left, right_v)) for k, right_v in serialized])
1008+
1009+
1010+
def bcast_right(
1011+
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
1012+
ArrayOrArithContainer],
1013+
left: ArrayOrArithContainer,
1014+
right: ArithArrayContainer,
1015+
) -> ArrayOrArithContainer:
1016+
try:
1017+
serialized = serialize_container(left)
1018+
except NotAnArrayContainerError:
1019+
return op(left, right)
1020+
1021+
return deserialize_container(right, [
1022+
(k, op(left_v, right)) for k, left_v in serialized])
1023+
1024+
1025+
def bcast_left_until_actx_array(
1026+
actx: ArrayContext,
1027+
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
1028+
ArrayOrArithContainer],
1029+
left: ArrayOrArithContainer,
1030+
right: ArithArrayContainer,
1031+
) -> ArrayOrArithContainer:
1032+
try:
1033+
serialized = serialize_container(right)
1034+
except NotAnArrayContainerError:
1035+
return op(left, right)
1036+
1037+
return deserialize_container(right, [
1038+
(k, op(left, right_v)
1039+
if isinstance(right_v, actx.array_types) else
1040+
bcast_left_until_actx_array(actx, op, left, right_v)
1041+
)
1042+
for k, right_v in serialized])
1043+
1044+
1045+
def bcast_right_until_actx_array(
1046+
actx: ArrayContext,
1047+
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
1048+
ArrayOrArithContainer],
1049+
left: ArrayOrArithContainer,
1050+
right: ArithArrayContainer,
1051+
) -> ArrayOrArithContainer:
1052+
try:
1053+
serialized = serialize_container(left)
1054+
except NotAnArrayContainerError:
1055+
return op(left, right)
1056+
1057+
return deserialize_container(right, [
1058+
(k, op(left_v, right)
1059+
if isinstance(left_v, actx.array_types) else
1060+
bcast_right_until_actx_array(actx, op, left_v, right)
1061+
)
1062+
for k, left_v in serialized])
1063+
1064+
# }}}
1065+
1066+
9901067
# vim: foldmethod=marker

0 commit comments

Comments
 (0)