6
6
.. currentmodule:: arraycontext
7
7
8
8
.. autofunction:: with_container_arithmetic
9
+ .. autoclass:: Bcast
10
+ .. autoclass:: BcastNLevels
11
+ .. autoclass:: BcastUntilActxArray
12
+
13
+ .. function:: Bcast1
14
+
15
+ Like :class:`BcastNLevels` with *nlevels* set to 1.
16
+
17
+ .. function:: Bcast2
18
+
19
+ Like :class:`BcastNLevels` with *nlevels* set to 2.
20
+
21
+ .. function:: Bcast3
22
+
23
+ Like :class:`BcastNLevels` with *nlevels* set to 3.
9
24
"""
10
25
11
26
34
49
"""
35
50
36
51
import enum
37
- from typing import Any , Callable , Optional , Tuple , TypeVar , Union
52
+ from abc import ABC , abstractmethod
53
+ from dataclasses import FrozenInstanceError
54
+ from functools import partial
55
+ from numbers import Number
56
+ from typing import Any , Callable , ClassVar , Optional , Tuple , TypeVar , Union
38
57
from warnings import warn
39
58
40
59
import numpy as np
41
60
61
+ from arraycontext .context import ArrayContext , ArrayOrContainer
62
+
42
63
43
64
# {{{ with_container_arithmetic
44
65
@@ -147,8 +168,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
147
168
warn (
148
169
"Broadcasting container against non-object numpy array. "
149
170
"This was never documented to work and will now stop working in "
150
- "2025. Convert the array to an object array to preserve the "
151
- "current semantics." , DeprecationWarning , stacklevel = 3 )
171
+ "2025. Convert the array to an object array or use "
172
+ "variants of arraycontext.Bcast to obtain the desired "
173
+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
152
174
return True
153
175
else :
154
176
return False
@@ -158,6 +180,125 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
158
180
pass
159
181
160
182
183
+ class Bcast :
184
+ """
185
+ A wrapper object to force arithmetic generated by :func:`with_container_arithmetic`
186
+ to broadcast *arg* across a container (with the container as the 'outer' structure).
187
+ Since array containers are often nested in complex ways, different subclasses
188
+ implement different rules on how broadcasting interacts with the hierarchy,
189
+ with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing
190
+ the most common.
191
+ """
192
+ arg : ArrayOrContainer
193
+
194
+ # Accessing this attribute is cheaper than isinstance, so use that
195
+ # to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand.
196
+ _with_next_operand : ClassVar [bool ]
197
+
198
+ def __init__ (self , arg : ArrayOrContainer ) -> None :
199
+ object .__setattr__ (self , "arg" , arg )
200
+
201
+ def __setattr__ (self , name : str , value : Any ) -> None :
202
+ raise FrozenInstanceError ()
203
+
204
+ def __delattr__ (self , name : str ) -> None :
205
+ raise FrozenInstanceError ()
206
+
207
+
208
+ class _BcastWithNextOperand (Bcast , ABC ):
209
+ """
210
+ A :class:`Bcast` object that gets to see who the next operand will be, in
211
+ order to decide whether wrapping the child in :class:`Bcast` is still necessary.
212
+ This is much more flexible, but also considerably more expensive, than
213
+ :class:`_BcastWithoutNextOperand`.
214
+ """
215
+
216
+ _with_next_operand = True
217
+
218
+ # purposefully undocumented
219
+ @abstractmethod
220
+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
221
+ ...
222
+
223
+
224
+ class _BcastWithoutNextOperand (Bcast , ABC ):
225
+ """
226
+ A :class:`Bcast` object that does not get to see who the next operand will be.
227
+ """
228
+ _with_next_operand = False
229
+
230
+ # purposefully undocumented
231
+ @abstractmethod
232
+ def _rewrap (self ) -> ArrayOrContainer :
233
+ ...
234
+
235
+
236
+ class BcastNLevels (_BcastWithoutNextOperand ):
237
+ """
238
+ A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of
239
+ array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as
240
+ convenient aliases for the common cases.
241
+
242
+ Usage example::
243
+
244
+ container + Bcast2(actx_array)
245
+
246
+ .. note::
247
+
248
+ :mod:`numpy` object arrays do not count against the number of levels.
249
+
250
+ .. automethod:: __init__
251
+ """
252
+ nlevels : int
253
+
254
+ def __init__ (self , nlevels : int , arg : ArrayOrContainer ) -> None :
255
+ if nlevels < 1 :
256
+ raise ValueError ("nlevels is expected to be one or greater." )
257
+
258
+ super ().__init__ (arg )
259
+ object .__setattr__ (self , "nlevels" , nlevels )
260
+
261
+ def _rewrap (self ) -> ArrayOrContainer :
262
+ if self .nlevels == 1 :
263
+ return self .arg
264
+ else :
265
+ return BcastNLevels (self .nlevels - 1 , self .arg )
266
+
267
+
268
+ Bcast1Level = partial (BcastNLevels , 1 )
269
+ Bcast2Levels = partial (BcastNLevels , 2 )
270
+ Bcast3Levels = partial (BcastNLevels , 3 )
271
+
272
+
273
+ class BcastUntilActxArray (_BcastWithNextOperand ):
274
+ """
275
+ A broadcast rule that broadcasts *arg* across array containers until
276
+ the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
277
+ of *actx*, or a :class:`~numbers.Number`.
278
+
279
+ Suggested usage pattern::
280
+
281
+ bcast = functools.partial(BcastUntilActxArray, actx)
282
+
283
+ container + bcast(actx_array)
284
+
285
+ .. automethod:: __init__
286
+ """
287
+ actx : ArrayContext
288
+
289
+ def __init__ (self ,
290
+ actx : ArrayContext ,
291
+ arg : ArrayOrContainer ) -> None :
292
+ super ().__init__ (arg )
293
+ object .__setattr__ (self , "actx" , actx )
294
+
295
+ def _rewrap (self , other_operand : ArrayOrContainer ) -> ArrayOrContainer :
296
+ if isinstance (other_operand , (* self .actx .array_types , Number )):
297
+ return self .arg
298
+ else :
299
+ return self
300
+
301
+
161
302
def with_container_arithmetic (
162
303
* ,
163
304
bcast_number : bool = True ,
@@ -206,6 +347,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
206
347
207
348
Each operator class also includes the "reverse" operators if applicable.
208
349
350
+ .. note::
351
+
352
+ For the generated binary arithmetic operators, if certain types
353
+ should be broadcast over the container (with the container as the
354
+ 'outer' structure) but are not handled in this way by their types,
355
+ you may wrap them in one of the :class:`Bcast` variants to achieve
356
+ the desired semantics.
357
+
209
358
.. note::
210
359
211
360
To generate the code implementing the operators, this function relies on
@@ -238,6 +387,24 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
238
387
#
239
388
# - Broadcast rules are hard to change once established, particularly
240
389
# because one cannot grep for their use.
390
+ #
391
+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
392
+ #
393
+ # - If one rule does not fit the user's need, they can straightforwardly use
394
+ # another.
395
+ #
396
+ # - It's straightforward to find where certain broadcast rules are used.
397
+ #
398
+ # - The broadcast rule can contain more state. For example, it's now easy
399
+ # for the rule to know what array context should be used to determine
400
+ # actx array types.
401
+ #
402
+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
403
+ #
404
+ # - User code is a bit more wordy.
405
+ #
406
+ # - Rewrapping has the potential to be costly, especially in
407
+ # _with_next_operand mode.
241
408
242
409
# {{{ handle inputs
243
410
@@ -349,9 +516,8 @@ def wrap(cls: Any) -> Any:
349
516
f"Broadcasting array context array types across { cls } "
350
517
"has been explicitly "
351
518
"enabled. As of 2025, this will stop working. "
352
- "There is no replacement as of right now. "
353
- "See the discussion in "
354
- "https://github.com/inducer/arraycontext/pull/190. "
519
+ "Express these operations using arraycontext.Bcast variants "
520
+ "instead. "
355
521
"To opt out now (and avoid this warning), "
356
522
"pass _bcast_actx_array_type=False. " ,
357
523
DeprecationWarning , stacklevel = 2 )
@@ -360,9 +526,8 @@ def wrap(cls: Any) -> Any:
360
526
f"Broadcasting array context array types across { cls } "
361
527
"has been implicitly "
362
528
"enabled. As of 2025, this will no longer work. "
363
- "There is no replacement as of right now. "
364
- "See the discussion in "
365
- "https://github.com/inducer/arraycontext/pull/190. "
529
+ "Express these operations using arraycontext.Bcast variants "
530
+ "instead. "
366
531
"To opt out now (and avoid this warning), "
367
532
"pass _bcast_actx_array_type=False." ,
368
533
DeprecationWarning , stacklevel = 2 )
@@ -380,7 +545,7 @@ def wrap(cls: Any) -> Any:
380
545
gen (f"""
381
546
from numbers import Number
382
547
import numpy as np
383
- from arraycontext import ArrayContainer
548
+ from arraycontext import ArrayContainer, Bcast
384
549
from warnings import warn
385
550
386
551
def _raise_if_actx_none(actx):
@@ -400,7 +565,8 @@ def is_numpy_array(arg):
400
565
"behavior will change in 2025. If you would like the "
401
566
"broadcasting behavior to stay the same, make sure "
402
567
"to convert the passed numpy array to an "
403
- "object array.",
568
+ "object array, or use arraycontext.Bcast to achieve "
569
+ "the desired broadcasting semantics.",
404
570
DeprecationWarning, stacklevel=3)
405
571
return True
406
572
else:
@@ -492,6 +658,33 @@ def {fname}(arg1):
492
658
cls ._serialize_init_arrays_code ("arg2" ).items ()
493
659
})
494
660
661
+ def get_operand (arg : Union [tuple [str , str ], str ]) -> str :
662
+ if isinstance (arg , tuple ):
663
+ entry , _container = arg
664
+ return entry
665
+ else :
666
+ return arg
667
+
668
+ bcast_init_args_arg1_is_outer_with_rewrap = \
669
+ cls ._deserialize_init_arrays_code ("arg1" , {
670
+ key_arg1 :
671
+ _format_binary_op_str (
672
+ op_str , expr_arg1 ,
673
+ f"arg2._rewrap({ get_operand (expr_arg1 )} )" )
674
+ for key_arg1 , expr_arg1 in
675
+ cls ._serialize_init_arrays_code ("arg1" ).items ()
676
+ })
677
+ bcast_init_args_arg2_is_outer_with_rewrap = \
678
+ cls ._deserialize_init_arrays_code ("arg2" , {
679
+ key_arg2 :
680
+ _format_binary_op_str (
681
+ op_str ,
682
+ f"arg1._rewrap({ get_operand (expr_arg2 )} )" ,
683
+ expr_arg2 )
684
+ for key_arg2 , expr_arg2 in
685
+ cls ._serialize_init_arrays_code ("arg2" ).items ()
686
+ })
687
+
495
688
# {{{ "forward" binary operators
496
689
497
690
gen (f"def { fname } (arg1, arg2):" )
@@ -544,14 +737,19 @@ def {fname}(arg1):
544
737
warn("Broadcasting { cls } over array "
545
738
f"context array type {{type(arg2)}} is deprecated "
546
739
"and will no longer work in 2025. "
547
- "There is no replacement as of right now. "
548
- "See the discussion in "
549
- "https://github.com/inducer/arraycontext/"
550
- "pull/190. ",
740
+ "Use arraycontext.Bcast to achieve the desired "
741
+ "broadcasting semantics.",
551
742
DeprecationWarning, stacklevel=2)
552
743
553
744
return cls({ bcast_init_args_arg1_is_outer } )
554
745
746
+ if isinstance(arg2, Bcast):
747
+ if arg2._with_next_operand:
748
+ return cls({ bcast_init_args_arg1_is_outer_with_rewrap } )
749
+ else:
750
+ arg2 = arg2._rewrap()
751
+ return cls({ bcast_init_args_arg1_is_outer } )
752
+
555
753
return NotImplemented
556
754
""" )
557
755
gen (f"cls.__{ dunder_name } __ = { fname } " )
@@ -595,14 +793,19 @@ def {fname}(arg2, arg1):
595
793
f"context array type {{type(arg1)}} "
596
794
"is deprecated "
597
795
"and will no longer work in 2025."
598
- "There is no replacement as of right now. "
599
- "See the discussion in "
600
- "https://github.com/inducer/arraycontext/"
601
- "pull/190. ",
796
+ "Use arraycontext.Bcast to achieve the "
797
+ "desired broadcasting semantics.",
602
798
DeprecationWarning, stacklevel=2)
603
799
604
800
return cls({ bcast_init_args_arg2_is_outer } )
605
801
802
+ if isinstance(arg1, Bcast):
803
+ if arg1._with_next_operand:
804
+ return cls({ bcast_init_args_arg2_is_outer_with_rewrap } )
805
+ else:
806
+ arg1 = arg1._rewrap()
807
+ return cls({ bcast_init_args_arg2_is_outer } )
808
+
606
809
return NotImplemented
607
810
608
811
cls.__r{ dunder_name } __ = { fname } """ )
0 commit comments