-
Notifications
You must be signed in to change notification settings - Fork 11
Add bcast
and bcast_until_actx_array functions
#307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,6 +78,7 @@ | |
import numpy as np | ||
|
||
from arraycontext.container import ( | ||
ArithArrayContainer, | ||
ArrayContainer, | ||
NotAnArrayContainerError, | ||
SerializationKey, | ||
|
@@ -88,6 +89,7 @@ | |
from arraycontext.context import ( | ||
Array, | ||
ArrayContext, | ||
ArrayOrArithContainer, | ||
ArrayOrContainer, | ||
ArrayOrContainerOrScalar, | ||
ArrayOrContainerT, | ||
|
@@ -987,4 +989,79 @@ def treat_as_scalar(x: Any) -> bool: | |
|
||
# }}} | ||
|
||
|
||
# {{{ | ||
inducer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def bcast_left( | ||
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], | ||
ArrayOrArithContainer], | ||
left: ArrayOrArithContainer, | ||
right: ArithArrayContainer, | ||
) -> ArrayOrArithContainer: | ||
try: | ||
serialized = serialize_container(right) | ||
except NotAnArrayContainerError: | ||
return op(left, right) | ||
|
||
return deserialize_container(right, [ | ||
(k, op(left, right_v)) for k, right_v in serialized]) | ||
|
||
|
||
def bcast_right( | ||
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], | ||
ArrayOrArithContainer], | ||
left: ArrayOrArithContainer, | ||
right: ArithArrayContainer, | ||
) -> ArrayOrArithContainer: | ||
try: | ||
serialized = serialize_container(left) | ||
except NotAnArrayContainerError: | ||
return op(left, right) | ||
|
||
return deserialize_container(right, [ | ||
(k, op(left_v, right)) for k, left_v in serialized]) | ||
|
||
|
||
def bcast_left_until_actx_array( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Naming? (May not matter as much, as it's convenient to make local aliases, as in the grudge PR.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I agree that local aliases are probably needed, so this should be nice and verbose.. i.e. I'm quite fine with the name. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with adding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Argument order? (Here and above.) I initially had the operator in the middle but went with "first" to allow convenient use of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the functional feel of it too 👍 |
||
actx: ArrayContext, | ||
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], | ||
ArrayOrArithContainer], | ||
left: ArrayOrArithContainer, | ||
right: ArithArrayContainer, | ||
) -> ArrayOrArithContainer: | ||
Comment on lines
+1026
to
+1031
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to have the interface for these be more like the container traversal functions, e.g.: def rec_bcast_left(op, left, right, leaf_cls: type | None = None):
... and then mul = partial(rec_bcast_left, operator.mul, leaf_cls=actx.array_types) ? |
||
try: | ||
serialized = serialize_container(right) | ||
except NotAnArrayContainerError: | ||
return op(left, right) | ||
|
||
return deserialize_container(right, [ | ||
(k, op(left, right_v) | ||
if isinstance(right_v, actx.array_types) else | ||
bcast_left_until_actx_array(actx, op, left, right_v) | ||
) | ||
for k, right_v in serialized]) | ||
|
||
|
||
def bcast_right_until_actx_array( | ||
actx: ArrayContext, | ||
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], | ||
ArrayOrArithContainer], | ||
left: ArrayOrArithContainer, | ||
right: ArithArrayContainer, | ||
) -> ArrayOrArithContainer: | ||
try: | ||
serialized = serialize_container(left) | ||
except NotAnArrayContainerError: | ||
return op(left, right) | ||
|
||
return deserialize_container(right, [ | ||
(k, op(left_v, right) | ||
if isinstance(left_v, actx.array_types) else | ||
bcast_right_until_actx_array(actx, op, left_v, right) | ||
) | ||
for k, left_v in serialized]) | ||
|
||
# }}} | ||
|
||
|
||
# vim: foldmethod=marker |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.