6
6
.. currentmodule:: arraycontext
7
7
8
8
.. autofunction:: with_container_arithmetic
9
+
10
+ .. autoclass:: BcastUntilActxArray
9
11
"""
10
12
11
13
34
36
"""
35
37
36
38
import enum
39
+ import operator
37
40
from collections .abc import Callable
41
+ from dataclasses import dataclass , field
42
+ from functools import partialmethod
43
+ from numbers import Number
38
44
from typing import Any , TypeVar
39
45
from warnings import warn
40
46
41
47
import numpy as np
42
48
49
+ from arraycontext .container import (
50
+ NotAnArrayContainerError ,
51
+ deserialize_container ,
52
+ serialize_container ,
53
+ )
54
+ from arraycontext .context import ArrayContext , ArrayOrContainer
55
+
43
56
44
57
# {{{ with_container_arithmetic
45
58
@@ -142,8 +155,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142
155
warn (
143
156
"Broadcasting container against non-object numpy array. "
144
157
"This was never documented to work and will now stop working in "
145
- "2025. Convert the array to an object array to preserve the "
146
- "current semantics." , DeprecationWarning , stacklevel = 3 )
158
+ "2025. Convert the array to an object array or use "
159
+ "arraycontext.BcastUntilActxArray (or similar) to obtain the desired "
160
+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
147
161
return True
148
162
else :
149
163
return False
@@ -207,6 +221,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207
221
208
222
Each operator class also includes the "reverse" operators if applicable.
209
223
224
+ .. note::
225
+
226
+ For the generated binary arithmetic operators, if certain types
227
+ should be broadcast over the container (with the container as the
228
+ 'outer' structure) but are not handled in this way by their types,
229
+ you may wrap them in :class:`BcastUntilActxArray` to achieve
230
+ the desired semantics.
231
+
210
232
.. note::
211
233
212
234
To generate the code implementing the operators, this function relies on
@@ -402,8 +424,9 @@ def wrap(cls: Any) -> Any:
402
424
warn (
403
425
f"Broadcasting array context array types across { cls } "
404
426
"has been explicitly "
405
- "enabled. As of 2025, this will stop working. "
406
- "There is no replacement as of right now. "
427
+ "enabled. As of 2026, this will stop working. "
428
+ "Use arraycontext.Bcast* object wrappers for "
429
+ "roughly equivalent functionality. "
407
430
"See the discussion in "
408
431
"https://github.com/inducer/arraycontext/pull/190. "
409
432
"To opt out now (and avoid this warning), "
@@ -413,8 +436,9 @@ def wrap(cls: Any) -> Any:
413
436
warn (
414
437
f"Broadcasting array context array types across { cls } "
415
438
"has been implicitly "
416
- "enabled. As of 2025, this will no longer work. "
417
- "There is no replacement as of right now. "
439
+ "enabled. As of 2026, this will no longer work. "
440
+ "Use arraycontext.Bcast* object wrappers for "
441
+ "roughly equivalent functionality. "
418
442
"See the discussion in "
419
443
"https://github.com/inducer/arraycontext/pull/190. "
420
444
"To opt out now (and avoid this warning), "
@@ -603,8 +627,9 @@ def {fname}(arg1):
603
627
if isinstance(arg2, { tup_str (bcast_actx_ary_types )} ):
604
628
warn("Broadcasting { cls } over array "
605
629
f"context array type {{type(arg2)}} is deprecated "
606
- "and will no longer work in 2025. "
607
- "There is no replacement as of right now. "
630
+ "and will no longer work in 2026. "
631
+ "Use arraycontext.Bcast* object wrappers for "
632
+ "roughly equivalent functionality. "
608
633
"See the discussion in "
609
634
"https://github.com/inducer/arraycontext/"
610
635
"pull/190. ",
@@ -654,8 +679,10 @@ def {fname}(arg2, arg1):
654
679
warn("Broadcasting { cls } over array "
655
680
f"context array type {{type(arg1)}} "
656
681
"is deprecated "
657
- "and will no longer work in 2025."
658
- "There is no replacement as of right now. "
682
+ "and will no longer work in 2026."
683
+ "Use arraycontext.Bcast* object "
684
+ "wrappers for roughly equivalent "
685
+ "functionality. "
659
686
"See the discussion in "
660
687
"https://github.com/inducer/arraycontext/"
661
688
"pull/190. ",
@@ -687,4 +714,111 @@ def {fname}(arg2, arg1):
687
714
# }}}
688
715
689
716
717
+ # {{{ Bcast object-ified broadcast rules
718
+
719
+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
720
+ #
721
+ # - If one rule does not fit the user's need, they can straightforwardly use
722
+ # another.
723
+ #
724
+ # - It's straightforward to find where certain broadcast rules are used.
725
+ #
726
+ # - The broadcast rule can contain more state. For example, it's now easy
727
+ # for the rule to know what array context should be used to determine
728
+ # actx array types.
729
+ #
730
+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
731
+ #
732
+ # - User code is a bit more wordy.
733
+
734
+ @dataclass (frozen = True )
735
+ class BcastUntilActxArray :
736
+ """
737
+ An operator-overloading wrapper around an object (*broadcastee*) that should be
738
+ broadcast across array containers until the 'opposite' operand is one of the
739
+ :attr:`~arraycontext.ArrayContext.array_types`
740
+ of *actx* or a :class:`~numbers.Number`.
741
+
742
+ Suggested usage pattern::
743
+
744
+ bcast = functools.partial(BcastUntilActxArray, actx)
745
+
746
+ container + bcast(actx_array)
747
+
748
+ .. automethod:: __init__
749
+ """
750
+
751
+ array_context : ArrayContext
752
+ broadcastee : ArrayOrContainer
753
+
754
+ _stop_types : tuple [type , ...] = field (init = False )
755
+
756
+ def __post_init__ (self ) -> None :
757
+ object .__setattr__ (
758
+ self , "_stop_types" , (* self .array_context .array_types , Number ))
759
+
760
+ def _binary_op (self ,
761
+ op : Callable [
762
+ [ArrayOrContainer , ArrayOrContainer ],
763
+ ArrayOrContainer
764
+ ],
765
+ right : ArrayOrContainer
766
+ ) -> ArrayOrContainer :
767
+ try :
768
+ serialized = serialize_container (right )
769
+ except NotAnArrayContainerError :
770
+ return op (self .broadcastee , right )
771
+
772
+ return deserialize_container (right , [
773
+ (k , op (self .broadcastee , right_v )
774
+ if isinstance (right_v , self ._stop_types ) else
775
+ self ._binary_op (op , right_v )
776
+ )
777
+ for k , right_v in serialized ])
778
+
779
+ def _rev_binary_op (self ,
780
+ op : Callable [
781
+ [ArrayOrContainer , ArrayOrContainer ],
782
+ ArrayOrContainer
783
+ ],
784
+ left : ArrayOrContainer
785
+ ) -> ArrayOrContainer :
786
+ try :
787
+ serialized = serialize_container (left )
788
+ except NotAnArrayContainerError :
789
+ return op (left , self .broadcastee )
790
+
791
+ return deserialize_container (left , [
792
+ (k , op (left_v , self .broadcastee )
793
+ if isinstance (left_v , self ._stop_types ) else
794
+ self ._rev_binary_op (op , left_v )
795
+ )
796
+ for k , left_v in serialized ])
797
+
798
+ __add__ = partialmethod (_binary_op , operator .add )
799
+ __radd__ = partialmethod (_rev_binary_op , operator .add )
800
+ __sub__ = partialmethod (_binary_op , operator .sub )
801
+ __rsub__ = partialmethod (_rev_binary_op , operator .sub )
802
+ __mul__ = partialmethod (_binary_op , operator .mul )
803
+ __rmul__ = partialmethod (_rev_binary_op , operator .mul )
804
+ __truediv__ = partialmethod (_binary_op , operator .truediv )
805
+ __rtruediv__ = partialmethod (_rev_binary_op , operator .truediv )
806
+ __floordiv__ = partialmethod (_binary_op , operator .floordiv )
807
+ __rfloordiv__ = partialmethod (_rev_binary_op , operator .floordiv )
808
+ __mod__ = partialmethod (_binary_op , operator .mod )
809
+ __rmod__ = partialmethod (_rev_binary_op , operator .mod )
810
+ __pow__ = partialmethod (_binary_op , operator .pow )
811
+ __rpow__ = partialmethod (_rev_binary_op , operator .pow )
812
+
813
+ __lshift__ = partialmethod (_binary_op , operator .lshift )
814
+ __rlshift__ = partialmethod (_rev_binary_op , operator .lshift )
815
+ __rshift__ = partialmethod (_binary_op , operator .rshift )
816
+ __rrshift__ = partialmethod (_rev_binary_op , operator .rshift )
817
+ __and__ = partialmethod (_binary_op , operator .and_ )
818
+ __rand__ = partialmethod (_rev_binary_op , operator .and_ )
819
+ __or__ = partialmethod (_binary_op , operator .or_ )
820
+ __ror__ = partialmethod (_rev_binary_op , operator .or_ )
821
+
822
+ # }}}
823
+
690
824
# vim: foldmethod=marker
0 commit comments