6
6
.. currentmodule:: arraycontext
7
7
8
8
.. autofunction:: with_container_arithmetic
9
+ .. autoclass:: Bcast
10
+ .. autoclass:: BcastNLevels
11
+ .. autoclass:: BcastUntilActxArray
12
+
13
+ .. function:: Bcast1
14
+
15
+ Like :class:`BcastNLevels` with *nlevels* set to 1.
16
+
17
+ .. function:: Bcast2
18
+
19
+ Like :class:`BcastNLevels` with *nlevels* set to 2.
20
+
21
+ .. function:: Bcast3
22
+
23
+ Like :class:`BcastNLevels` with *nlevels* set to 3.
9
24
"""
10
25
11
26
34
49
"""
35
50
36
51
import enum
52
+ from abc import ABC , abstractmethod
37
53
from collections .abc import Callable
38
- from typing import Any , TypeVar
54
+ from dataclasses import FrozenInstanceError
55
+ from functools import partial
56
+ from numbers import Number
57
+ from typing import Any , ClassVar , TypeVar , Union
39
58
from warnings import warn
40
59
41
60
import numpy as np
42
61
62
+ from arraycontext .context import ArrayContext , ArrayOrContainer
63
+
43
64
44
65
# {{{ with_container_arithmetic
45
66
@@ -142,8 +163,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142
163
warn (
143
164
"Broadcasting container against non-object numpy array. "
144
165
"This was never documented to work and will now stop working in "
145
- "2025. Convert the array to an object array to preserve the "
146
- "current semantics." , DeprecationWarning , stacklevel = 3 )
166
+ "2025. Convert the array to an object array or use "
167
+ "variants of arraycontext.Bcast to obtain the desired "
168
+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
147
169
return True
148
170
else :
149
171
return False
@@ -153,6 +175,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
153
175
pass
154
176
155
177
178
+ class Bcast :
179
+ """
180
+ A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
181
+ to broadcast *arg* across a container (with the container as the 'outer' structure).
182
+ Since array containers are often nested in complex ways, different subclasses
183
+ implement different rules on how broadcasting interacts with the hierarchy,
184
+ with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
185
+ the most common.
186
+ """
187
+ arg : ArrayOrContainer
188
+
189
+ # Accessing this attribute is cheaper than isinstance, so use that
190
+ # to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
191
+ _with_next_operand : ClassVar [bool ]
192
+
193
+ def __init__ (self , arg : ArrayOrContainer ) -> None :
194
+ object .__setattr__ (self , "arg" , arg )
195
+
196
+ def __setattr__ (self , name : str , value : Any ) -> None :
197
+ raise FrozenInstanceError ()
198
+
199
+ def __delattr__ (self , name : str ) -> None :
200
+ raise FrozenInstanceError ()
201
+
202
+
203
+ class _BcastWithNextOperand (Bcast , ABC ):
204
+ """
205
+ A :class:`Bcast` object that gets to see who the next operand will be, in
206
+ order to decide whether wrapping the child in :class:`Bcast` is still necessary.
207
+ This is much more flexible, but also considerably more expensive, than
208
+ :class:`_BcastWithoutNextOperand`.
209
+ """
210
+
211
+ _with_next_operand = True
212
+
213
+ # purposefully undocumented
214
+ @abstractmethod
215
+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
216
+ ...
217
+
218
+
219
+ class _BcastWithoutNextOperand (Bcast , ABC ):
220
+ """
221
+ A :class:`Bcast` object that does not get to see who the next operand will be.
222
+ """
223
+ _with_next_operand = False
224
+
225
+ # purposefully undocumented
226
+ @abstractmethod
227
+ def _rewrap (self ) -> ArrayOrContainer :
228
+ ...
229
+
230
+
231
+ class BcastNLevels (_BcastWithoutNextOperand ):
232
+ """
233
+ A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
234
+ array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
235
+ convenient aliases for the common cases.
236
+
237
+ Usage example::
238
+
239
+ container + Bcast2(actx_array)
240
+
241
+ .. note::
242
+
243
+ :mod:`numpy` object arrays do not count against the number of levels.
244
+
245
+ .. automethod:: __init__
246
+ """
247
+ nlevels : int
248
+
249
+ def __init__ (self , nlevels : int , arg : ArrayOrContainer ) -> None :
250
+ if nlevels < 1 :
251
+ raise ValueError ("nlevels is expected to be one or greater." )
252
+
253
+ super ().__init__ (arg )
254
+ object .__setattr__ (self , "nlevels" , nlevels )
255
+
256
+ def _rewrap (self ) -> ArrayOrContainer :
257
+ if self .nlevels == 1 :
258
+ return self .arg
259
+ else :
260
+ return BcastNLevels (self .nlevels - 1 , self .arg )
261
+
262
+
263
+ Bcast1Level = partial (BcastNLevels , 1 )
264
+ Bcast2Levels = partial (BcastNLevels , 2 )
265
+ Bcast3Levels = partial (BcastNLevels , 3 )
266
+
267
+
268
+ class BcastUntilActxArray (_BcastWithNextOperand ):
269
+ """
270
+ A broadcast rule that broadcasts *arg* across array containers until
271
+ the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
272
+ of *actx*, or a :class:`~numbers.Number`.
273
+
274
+ Suggested usage pattern::
275
+
276
+ bcast = functools.partial(BcastUntilActxArray, actx)
277
+
278
+ container + bcast(actx_array)
279
+
280
+ .. automethod:: __init__
281
+ """
282
+ actx : ArrayContext
283
+
284
+ def __init__ (self ,
285
+ actx : ArrayContext ,
286
+ arg : ArrayOrContainer ) -> None :
287
+ super ().__init__ (arg )
288
+ object .__setattr__ (self , "actx" , actx )
289
+
290
+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
291
+ if isinstance (other_operand , (* self .actx .array_types , Number )):
292
+ return self .arg
293
+ else :
294
+ return self
295
+
296
+
156
297
def with_container_arithmetic (
157
298
* ,
158
299
number_bcasts_across : bool | None = None ,
@@ -207,6 +348,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207
348
208
349
Each operator class also includes the "reverse" operators if applicable.
209
350
351
+ .. note::
352
+
353
+ For the generated binary arithmetic operators, if certain types
354
+ should be broadcast over the container (with the container as the
355
+ 'outer' structure) but are not handled in this way by their types,
356
+ you may wrap them in one of the :class:`Bcast` variants to achieve
357
+ the desired semantics.
358
+
210
359
.. note::
211
360
212
361
To generate the code implementing the operators, this function relies on
@@ -239,6 +388,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
239
388
#
240
389
# - Broadcast rules are hard to change once established, particularly
241
390
# because one cannot grep for their use.
391
+ #
392
+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
393
+ #
394
+ # - If one rule does not fit the user's need, they can straightforwardly use
395
+ # another.
396
+ #
397
+ # - It's straightforward to find where certain broadcast rules are used.
398
+ #
399
+ # - The broadcast rule can contain more state. For example, it's now easy
400
+ # for the rule to know what array context should be used to determine
401
+ # actx array types.
402
+ #
403
+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
404
+ #
405
+ # - User code is a bit more wordy.
406
+ #
407
+ # - Rewrapping has the potential to be costly, especially in
408
+ # _with_next_operand mode.
242
409
243
410
# {{{ handle inputs
244
411
@@ -404,9 +571,8 @@ def wrap(cls: Any) -> Any:
404
571
f"Broadcasting array context array types across { cls } "
405
572
"has been explicitly "
406
573
"enabled. As of 2025, this will stop working. "
407
- "There is no replacement as of right now. "
408
- "See the discussion in "
409
- "https://github.com/inducer/arraycontext/pull/190. "
574
+ "Express these operations using arraycontext.Bcast variants "
575
+ "instead. "
410
576
"To opt out now (and avoid this warning), "
411
577
"pass _bcast_actx_array_type=False. " ,
412
578
DeprecationWarning , stacklevel = 2 )
@@ -415,9 +581,8 @@ def wrap(cls: Any) -> Any:
415
581
f"Broadcasting array context array types across { cls } "
416
582
"has been implicitly "
417
583
"enabled. As of 2025, this will no longer work. "
418
- "There is no replacement as of right now. "
419
- "See the discussion in "
420
- "https://github.com/inducer/arraycontext/pull/190. "
584
+ "Express these operations using arraycontext.Bcast variants "
585
+ "instead. "
421
586
"To opt out now (and avoid this warning), "
422
587
"pass _bcast_actx_array_type=False." ,
423
588
DeprecationWarning , stacklevel = 2 )
@@ -435,7 +600,7 @@ def wrap(cls: Any) -> Any:
435
600
gen (f"""
436
601
from numbers import Number
437
602
import numpy as np
438
- from arraycontext import ArrayContainer
603
+ from arraycontext import ArrayContainer, Bcast
439
604
from warnings import warn
440
605
441
606
def _raise_if_actx_none(actx):
@@ -455,7 +620,8 @@ def is_numpy_array(arg):
455
620
"behavior will change in 2025. If you would like the "
456
621
"broadcasting behavior to stay the same, make sure "
457
622
"to convert the passed numpy array to an "
458
- "object array.",
623
+ "object array, or use arraycontext.Bcast to achieve "
624
+ "the desired broadcasting semantics.",
459
625
DeprecationWarning, stacklevel=3)
460
626
return True
461
627
else:
@@ -553,6 +719,33 @@ def {fname}(arg1):
553
719
cls ._serialize_init_arrays_code ("arg2" ).items ()
554
720
})
555
721
722
+ def get_operand (arg : Union [tuple [str , str ], str ]) -> str :
723
+ if isinstance (arg , tuple ):
724
+ entry , _container = arg
725
+ return entry
726
+ else :
727
+ return arg
728
+
729
+ bcast_init_args_arg1_is_outer_with_rewrap = \
730
+ cls ._deserialize_init_arrays_code ("arg1" , {
731
+ key_arg1 :
732
+ _format_binary_op_str (
733
+ op_str , expr_arg1 ,
734
+ f"arg2._rewrap({ get_operand (expr_arg1 )} )" )
735
+ for key_arg1 , expr_arg1 in
736
+ cls ._serialize_init_arrays_code ("arg1" ).items ()
737
+ })
738
+ bcast_init_args_arg2_is_outer_with_rewrap = \
739
+ cls ._deserialize_init_arrays_code ("arg2" , {
740
+ key_arg2 :
741
+ _format_binary_op_str (
742
+ op_str ,
743
+ f"arg1._rewrap({ get_operand (expr_arg2 )} )" ,
744
+ expr_arg2 )
745
+ for key_arg2 , expr_arg2 in
746
+ cls ._serialize_init_arrays_code ("arg2" ).items ()
747
+ })
748
+
556
749
# {{{ "forward" binary operators
557
750
558
751
gen (f"def { fname } (arg1, arg2):" )
@@ -605,14 +798,19 @@ def {fname}(arg1):
605
798
warn("Broadcasting { cls } over array "
606
799
f"context array type {{type(arg2)}} is deprecated "
607
800
"and will no longer work in 2025. "
608
- "There is no replacement as of right now. "
609
- "See the discussion in "
610
- "https://github.com/inducer/arraycontext/"
611
- "pull/190. ",
801
+ "Use arraycontext.Bcast to achieve the desired "
802
+ "broadcasting semantics.",
612
803
DeprecationWarning, stacklevel=2)
613
804
614
805
return cls({ bcast_init_args_arg1_is_outer } )
615
806
807
+ if isinstance(arg2, Bcast):
808
+ if arg2._with_next_operand:
809
+ return cls({ bcast_init_args_arg1_is_outer_with_rewrap } )
810
+ else:
811
+ arg2 = arg2._rewrap()
812
+ return cls({ bcast_init_args_arg1_is_outer } )
813
+
616
814
return NotImplemented
617
815
""" )
618
816
gen (f"cls.__{ dunder_name } __ = { fname } " )
@@ -656,14 +854,19 @@ def {fname}(arg2, arg1):
656
854
f"context array type {{type(arg1)}} "
657
855
"is deprecated "
658
856
"and will no longer work in 2025."
659
- "There is no replacement as of right now. "
660
- "See the discussion in "
661
- "https://github.com/inducer/arraycontext/"
662
- "pull/190. ",
857
+ "Use arraycontext.Bcast to achieve the "
858
+ "desired broadcasting semantics.",
663
859
DeprecationWarning, stacklevel=2)
664
860
665
861
return cls({ bcast_init_args_arg2_is_outer } )
666
862
863
+ if isinstance(arg1, Bcast):
864
+ if arg1._with_next_operand:
865
+ return cls({ bcast_init_args_arg2_is_outer_with_rewrap } )
866
+ else:
867
+ arg1 = arg1._rewrap()
868
+ return cls({ bcast_init_args_arg2_is_outer } )
869
+
667
870
return NotImplemented
668
871
669
872
cls.__r{ dunder_name } __ = { fname } """ )
0 commit comments