Skip to content

Add bcast and bcast_until_actx_array functions #307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
)
from .container.dataclass import dataclass_array_container
from .container.traversal import (
bcast_left,
bcast_left_until_actx_array,
bcast_right,
bcast_right_until_actx_array,
flat_size_and_dtype,
flatten,
freeze,
Expand Down Expand Up @@ -129,6 +133,10 @@
"ScalarLike",
"SerializationKey",
"SerializedContainer",
"bcast_left",
"bcast_left_until_actx_array",
"bcast_right",
"bcast_right_until_actx_array",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
Expand Down
22 changes: 8 additions & 14 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,8 @@ def wrap(cls: Any) -> Any:
warn(
f"Broadcasting array context array types across {cls} "
"has been implicitly "
"enabled. As of 2025, this will no longer work. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/pull/190. "
"enabled. As of 2026, this will no longer work. "
"Use bcast_left/right_until_actx_array instead."
"To opt out now (and avoid this warning), "
"pass _bcast_actx_array_type=False.",
DeprecationWarning, stacklevel=2)
Expand Down Expand Up @@ -603,11 +601,9 @@ def {fname}(arg1):
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg2)}} is deprecated "
"and will no longer work in 2025. "
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
"and will no longer work in 2026. "
"Use bcast_left/right_until_actx_array "
"instead.",
DeprecationWarning, stacklevel=2)

return cls({bcast_init_args_arg1_is_outer})
Expand Down Expand Up @@ -654,11 +650,9 @@ def {fname}(arg2, arg1):
warn("Broadcasting {cls} over array "
f"context array type {{type(arg1)}} "
"is deprecated "
"and will no longer work in 2025."
"There is no replacement as of right now. "
"See the discussion in "
"https://github.com/inducer/arraycontext/"
"pull/190. ",
"and will no longer work in 2026."
"Use bcast_left/right_until_actx_array "
"instead.",
DeprecationWarning, stacklevel=2)

return cls({bcast_init_args_arg2_is_outer})
Expand Down
77 changes: 77 additions & 0 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import numpy as np

from arraycontext.container import (
ArithArrayContainer,
ArrayContainer,
NotAnArrayContainerError,
SerializationKey,
Expand All @@ -88,6 +89,7 @@
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrArithContainer,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerT,
Expand Down Expand Up @@ -987,4 +989,79 @@ def treat_as_scalar(x: Any) -> bool:

# }}}


# {{{
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# {{{
# {{{ bcast


def bcast_left(
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
ArrayOrArithContainer],
left: ArrayOrArithContainer,
right: ArithArrayContainer,
) -> ArrayOrArithContainer:
try:
serialized = serialize_container(right)
except NotAnArrayContainerError:
return op(left, right)

return deserialize_container(right, [
(k, op(left, right_v)) for k, right_v in serialized])


def bcast_right(
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
ArrayOrArithContainer],
left: ArrayOrArithContainer,
right: ArithArrayContainer,
) -> ArrayOrArithContainer:
try:
serialized = serialize_container(left)
except NotAnArrayContainerError:
return op(left, right)

return deserialize_container(right, [
(k, op(left_v, right)) for k, left_v in serialized])


def bcast_left_until_actx_array(
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming? (May not matter as much, as it's convenient to make local aliases, as in the grudge PR.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rec_bcast_left? Mostly to match some of the other traversal functions.

I agree that local aliases are probably needed, so this should be nice and verbose.. i.e. I'm quite fine with the name.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with adding rec_, and I would suggest maybe adding something else to disambiguate what's happening to the "left". (Picturing myself looking at this again in 6 months and not being sure if it means "broadcast the left across" or "broadcast across the left".) rec_bcast_left_operand_across_actx_arrays or something (probably too verbose, maybe you can think of something better).

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argument order? (Here and above.) I initially had the operator in the middle but went with "first" to allow convenient use of partial.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the functional feel of it too 👍

actx: ArrayContext,
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
ArrayOrArithContainer],
left: ArrayOrArithContainer,
right: ArithArrayContainer,
) -> ArrayOrArithContainer:
Comment on lines +1026 to +1031
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to have the interface for these be more like the container traversal functions, e.g.:

def rec_bcast_left(op, left, right, leaf_cls: type | None = None):
    ...

and then

mul = partial(rec_bcast_left, operator.mul, leaf_cls=actx.array_types)

?

try:
serialized = serialize_container(right)
except NotAnArrayContainerError:
return op(left, right)

return deserialize_container(right, [
(k, op(left, right_v)
if isinstance(right_v, actx.array_types) else
bcast_left_until_actx_array(actx, op, left, right_v)
)
for k, right_v in serialized])


def bcast_right_until_actx_array(
actx: ArrayContext,
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
ArrayOrArithContainer],
left: ArrayOrArithContainer,
right: ArithArrayContainer,
) -> ArrayOrArithContainer:
try:
serialized = serialize_container(left)
except NotAnArrayContainerError:
return op(left, right)

return deserialize_container(right, [
(k, op(left_v, right)
if isinstance(left_v, actx.array_types) else
bcast_right_until_actx_array(actx, op, left_v, right)
)
for k, left_v in serialized])

# }}}


# vim: foldmethod=marker
Loading