2
2
from __future__ import annotations
3
3
4
4
5
- """
5
+ __doc__ = """
6
6
.. currentmodule:: arraycontext
7
+
7
8
.. autofunction:: with_container_arithmetic
8
9
"""
9
10
10
- import enum
11
-
12
11
13
12
__copyright__ = """
14
13
Copyright (C) 2020-1 University of Illinois Board of Trustees
34
33
THE SOFTWARE.
35
34
"""
36
35
36
+ import enum
37
37
from typing import Any , Callable , Optional , Tuple , TypeVar , Union
38
+ from warnings import warn
38
39
39
40
import numpy as np
40
41
@@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
99
100
100
101
101
102
def _format_binary_op_str (op_str : str ,
102
- arg1 : Union [Tuple [str , ... ], str ],
103
- arg2 : Union [Tuple [str , ... ], str ]) -> str :
103
+ arg1 : Union [Tuple [str , str ], str ],
104
+ arg2 : Union [Tuple [str , str ], str ]) -> str :
104
105
if isinstance (arg1 , tuple ) and isinstance (arg2 , tuple ):
105
106
import sys
106
107
if sys .version_info >= (3 , 10 ):
@@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str,
127
128
return op_str .format (arg1 , arg2 )
128
129
129
130
131
+ class NumpyObjectArrayMetaclass (type ):
132
+ def __instancecheck__ (cls , instance : Any ) -> bool :
133
+ return isinstance (instance , np .ndarray ) and instance .dtype == object
134
+
135
+
136
+ class NumpyObjectArray (metaclass = NumpyObjectArrayMetaclass ):
137
+ pass
138
+
139
+
140
+ class ComplainingNumpyNonObjectArrayMetaclass (type ):
141
+ def __instancecheck__ (cls , instance : Any ) -> bool :
142
+ if isinstance (instance , np .ndarray ) and instance .dtype != object :
143
+ # Example usage site:
144
+ # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149
145
+ # where normal is passed in by test_lfr_flux as a 'custom-made'
146
+ # numpy array of dtype float64.
147
+ warn (
148
+ "Broadcasting container against non-object numpy array. "
149
+ "This was never documented to work and will now stop working in "
150
+ "2025. Convert the array to an object array to preserve the "
151
+ "current semantics." , DeprecationWarning , stacklevel = 3 )
152
+ return True
153
+ else :
154
+ return False
155
+
156
+
157
+ class ComplainingNumpyNonObjectArray (metaclass = ComplainingNumpyNonObjectArrayMetaclass ):
158
+ pass
159
+
160
+
130
161
def with_container_arithmetic (
131
162
* ,
132
163
bcast_number : bool = True ,
@@ -146,22 +177,16 @@ def with_container_arithmetic(
146
177
147
178
:arg bcast_number: If *True*, numbers broadcast over the container
148
179
(with the container as the 'outer' structure).
149
- :arg _bcast_actx_array_type: If *True*, instances of base array types of the
150
- container's array context are broadcasted over the container. Can be
151
- *True* only if the container has *_cls_has_array_context_attr* set.
152
- Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set,
153
- else *False*.
154
- :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over
155
- the container. (with the container as the 'inner' structure)
156
- :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast
157
- over the container. (with the container as the 'inner' structure)
158
- If this is set to *True*, *bcast_obj_array* must also be *True*.
180
+ :arg bcast_obj_array: If *True*, this container will be broadcast
181
+ across :mod:`numpy` object arrays
182
+ (with the object array as the 'outer' structure).
183
+ Add :class:`numpy.ndarray` to *bcast_container_types* to achieve
184
+ the 'reverse' broadcasting.
159
185
:arg bcast_container_types: A sequence of container types that will broadcast
160
- over this container ( with this container as the 'outer' structure) .
186
+ across this container, with this container as the 'outer' structure.
161
187
:class:`numpy.ndarray` is permitted to be part of this sequence to
162
- indicate that, in such broadcasting situations, this container should
163
- be the 'outer' structure. In this case, *bcast_obj_array*
164
- (and consequently *bcast_numpy_array*) must be *False*.
188
+ indicate that object arrays (and *only* object arrays) will be broadcasat.
189
+ In this case, *bcast_obj_array* must be *False*.
165
190
:arg arithmetic: Implement the conventional arithmetic operators, including
166
191
``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
167
192
:func:`abs`.
@@ -203,6 +228,17 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
203
228
should nest "outside" :func:dataclass_array_container`.
204
229
"""
205
230
231
+ # Hard-won design lessons:
232
+ #
233
+ # - Anything that special-cases np.ndarray by type is broken by design because:
234
+ # - np.ndarray is an array context array.
235
+ # - numpy object arrays can be array containers.
236
+ # Using NumpyObjectArray and NumpyNonObjectArray *may* be better?
237
+ # They're new, so there is no operational experience with them.
238
+ #
239
+ # - Broadcast rules are hard to change once established, particularly
240
+ # because one cannot grep for their use.
241
+
206
242
# {{{ handle inputs
207
243
208
244
if bcast_obj_array is None :
@@ -212,9 +248,8 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
212
248
raise TypeError ("rel_comparison must be specified" )
213
249
214
250
if bcast_numpy_array :
215
- from warnings import warn
216
251
warn ("'bcast_numpy_array=True' is deprecated and will be unsupported"
217
- " from December 2021 " , DeprecationWarning , stacklevel = 2 )
252
+ " from 2025. " , DeprecationWarning , stacklevel = 2 )
218
253
219
254
if _bcast_actx_array_type :
220
255
raise ValueError ("'bcast_numpy_array' and '_bcast_actx_array_type'"
@@ -231,7 +266,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
231
266
232
267
if bcast_numpy_array :
233
268
def numpy_pred (name : str ) -> str :
234
- return f"isinstance ({ name } , np.ndarray )"
269
+ return f"is_numpy_array ({ name } )"
235
270
elif bcast_obj_array :
236
271
def numpy_pred (name : str ) -> str :
237
272
return f"isinstance({ name } , np.ndarray) and { name } .dtype.char == 'O'"
@@ -241,12 +276,21 @@ def numpy_pred(name: str) -> str:
241
276
242
277
if bcast_container_types is None :
243
278
bcast_container_types = ()
244
- bcast_container_types_count = len (bcast_container_types )
245
279
246
280
if np .ndarray in bcast_container_types and bcast_obj_array :
247
281
raise ValueError ("If numpy.ndarray is part of bcast_container_types, "
248
282
"bcast_obj_array must be False." )
249
283
284
+ numpy_check_types : list [type ] = [NumpyObjectArray , ComplainingNumpyNonObjectArray ]
285
+ bcast_container_types = tuple (
286
+ new_ct
287
+ for old_ct in bcast_container_types
288
+ for new_ct in
289
+ (numpy_check_types
290
+ if old_ct is np .ndarray
291
+ else [old_ct ])
292
+ )
293
+
250
294
desired_op_classes = set ()
251
295
if arithmetic :
252
296
desired_op_classes .add (_OpClass .ARITHMETIC )
@@ -264,19 +308,24 @@ def numpy_pred(name: str) -> str:
264
308
# }}}
265
309
266
310
def wrap (cls : Any ) -> Any :
267
- cls_has_array_context_attr : bool | None = \
268
- _cls_has_array_context_attr
269
- bcast_actx_array_type : bool | None = \
270
- _bcast_actx_array_type
311
+ if not hasattr (cls , "__array_ufunc__" ):
312
+ warn (f"{ cls } does not have __array_ufunc__ set. "
313
+ "This will cause numpy to attempt broadcasting, in a way that "
314
+ "is likely undesired. "
315
+ f"To avoid this, set __array_ufunc__ = None in { cls } ." ,
316
+ stacklevel = 2 )
317
+
318
+ cls_has_array_context_attr : bool | None = _cls_has_array_context_attr
319
+ bcast_actx_array_type : bool | None = _bcast_actx_array_type
271
320
272
321
if cls_has_array_context_attr is None :
273
322
if hasattr (cls , "array_context" ):
274
323
raise TypeError (
275
324
f"{ cls } has an 'array_context' attribute, but it does not "
276
325
"set '_cls_has_array_context_attr' to *True* when calling "
277
326
"with_container_arithmetic. This is being interpreted "
278
- "as 'array_context' being permitted to fail with an exception, "
279
- "which is no longer allowed. "
327
+ "as '. array_context' being permitted to fail "
328
+ "with an exception, which is no longer allowed. "
280
329
f"If { cls .__name__ } .array_context will not fail, pass "
281
330
"'_cls_has_array_context_attr=True'. "
282
331
"If you do not want container arithmetic to make "
@@ -294,6 +343,30 @@ def wrap(cls: Any) -> Any:
294
343
raise TypeError ("_bcast_actx_array_type can be True only if "
295
344
"_cls_has_array_context_attr is set." )
296
345
346
+ if bcast_actx_array_type :
347
+ if _bcast_actx_array_type :
348
+ warn (
349
+ f"Broadcasting array context array types across { cls } "
350
+ "has been explicitly "
351
+ "enabled. As of 2025, this will stop working. "
352
+ "There is no replacement as of right now. "
353
+ "See the discussion in "
354
+ "https://github.com/inducer/arraycontext/pull/190. "
355
+ "To opt out now (and avoid this warning), "
356
+ "pass _bcast_actx_array_type=False. " ,
357
+ DeprecationWarning , stacklevel = 2 )
358
+ else :
359
+ warn (
360
+ f"Broadcasting array context array types across { cls } "
361
+ "has been implicitly "
362
+ "enabled. As of 2025, this will no longer work. "
363
+ "There is no replacement as of right now. "
364
+ "See the discussion in "
365
+ "https://github.com/inducer/arraycontext/pull/190. "
366
+ "To opt out now (and avoid this warning), "
367
+ "pass _bcast_actx_array_type=False." ,
368
+ DeprecationWarning , stacklevel = 2 )
369
+
297
370
if (not hasattr (cls , "_serialize_init_arrays_code" )
298
371
or not hasattr (cls , "_deserialize_init_arrays_code" )):
299
372
raise TypeError (f"class '{ cls .__name__ } ' must provide serialization "
@@ -304,7 +377,7 @@ def wrap(cls: Any) -> Any:
304
377
305
378
from pytools .codegen import CodeGenerator , Indentation
306
379
gen = CodeGenerator ()
307
- gen ("""
380
+ gen (f """
308
381
from numbers import Number
309
382
import numpy as np
310
383
from arraycontext import ArrayContainer
@@ -315,6 +388,24 @@ def _raise_if_actx_none(actx):
315
388
raise ValueError("array containers with frozen arrays "
316
389
"cannot be operated upon")
317
390
return actx
391
+
392
+ def is_numpy_array(arg):
393
+ if isinstance(arg, np.ndarray):
394
+ if arg.dtype != "O":
395
+ warn("Operand is a non-object numpy array, "
396
+ "and the broadcasting behavior of this array container "
397
+ "({ cls } ) "
398
+ "is influenced by this because of its use of "
399
+ "the deprecated bcast_numpy_array. This broadcasting "
400
+ "behavior will change in 2025. If you would like the "
401
+ "broadcasting behavior to stay the same, make sure "
402
+ "to convert the passed numpy array to an "
403
+ "object array.",
404
+ DeprecationWarning, stacklevel=3)
405
+ return True
406
+ else:
407
+ return False
408
+
318
409
""" )
319
410
gen ("" )
320
411
@@ -323,7 +414,7 @@ def _raise_if_actx_none(actx):
323
414
gen (f"from { bct .__module__ } import { bct .__qualname__ } as _bctype{ i } " )
324
415
gen ("" )
325
416
outer_bcast_type_names = tuple ([
326
- f"_bctype{ i } " for i in range (bcast_container_types_count )
417
+ f"_bctype{ i } " for i in range (len ( bcast_container_types ) )
327
418
])
328
419
if bcast_number :
329
420
outer_bcast_type_names += ("Number" ,)
@@ -384,20 +475,25 @@ def {fname}(arg1):
384
475
385
476
continue
386
477
387
- # {{{ "forward" binary operators
388
-
389
478
zip_init_args = cls ._deserialize_init_arrays_code ("arg1" , {
390
479
same_key (key_arg1 , key_arg2 ):
391
480
_format_binary_op_str (op_str , expr_arg1 , expr_arg2 )
392
481
for (key_arg1 , expr_arg1 ), (key_arg2 , expr_arg2 ) in zip (
393
482
cls ._serialize_init_arrays_code ("arg1" ).items (),
394
483
cls ._serialize_init_arrays_code ("arg2" ).items ())
395
484
})
396
- bcast_same_cls_init_args = cls ._deserialize_init_arrays_code ("arg1" , {
485
+ bcast_init_args_arg1_is_outer = cls ._deserialize_init_arrays_code ("arg1" , {
397
486
key_arg1 : _format_binary_op_str (op_str , expr_arg1 , "arg2" )
398
487
for key_arg1 , expr_arg1 in
399
488
cls ._serialize_init_arrays_code ("arg1" ).items ()
400
489
})
490
+ bcast_init_args_arg2_is_outer = cls ._deserialize_init_arrays_code ("arg2" , {
491
+ key_arg2 : _format_binary_op_str (op_str , "arg1" , expr_arg2 )
492
+ for key_arg2 , expr_arg2 in
493
+ cls ._serialize_init_arrays_code ("arg2" ).items ()
494
+ })
495
+
496
+ # {{{ "forward" binary operators
401
497
402
498
gen (f"def { fname } (arg1, arg2):" )
403
499
with Indentation (gen ):
@@ -424,7 +520,7 @@ def {fname}(arg1):
424
520
425
521
if bcast_actx_array_type :
426
522
if __debug__ :
427
- bcast_actx_ary_types = (
523
+ bcast_actx_ary_types : tuple [ str , ...] = (
428
524
"*_raise_if_actx_none("
429
525
"arg1.array_context).array_types" ,)
430
526
else :
@@ -444,7 +540,19 @@ def {fname}(arg1):
444
540
if isinstance(arg2,
445
541
{ tup_str (outer_bcast_type_names
446
542
+ bcast_actx_ary_types )} ):
447
- return cls({ bcast_same_cls_init_args } )
543
+ if __debug__:
544
+ if isinstance(arg2, { tup_str (bcast_actx_ary_types )} ):
545
+ warn("Broadcasting { cls } over array "
546
+ f"context array type {{type(arg2)}} is deprecated "
547
+ "and will no longer work in 2025. "
548
+ "There is no replacement as of right now. "
549
+ "See the discussion in "
550
+ "https://github.com/inducer/arraycontext/"
551
+ "pull/190. ",
552
+ DeprecationWarning, stacklevel=2)
553
+
554
+ return cls({ bcast_init_args_arg1_is_outer } )
555
+
448
556
return NotImplemented
449
557
""" )
450
558
gen (f"cls.__{ dunder_name } __ = { fname } " )
@@ -456,12 +564,6 @@ def {fname}(arg1):
456
564
457
565
if reversible :
458
566
fname = f"_{ cls .__name__ .lower ()} _r{ dunder_name } "
459
- bcast_init_args = cls ._deserialize_init_arrays_code ("arg2" , {
460
- key_arg2 : _format_binary_op_str (
461
- op_str , "arg1" , expr_arg2 )
462
- for key_arg2 , expr_arg2 in
463
- cls ._serialize_init_arrays_code ("arg2" ).items ()
464
- })
465
567
466
568
if bcast_actx_array_type :
467
569
if __debug__ :
@@ -487,7 +589,21 @@ def {fname}(arg2, arg1):
487
589
if isinstance(arg1,
488
590
{ tup_str (outer_bcast_type_names
489
591
+ bcast_actx_ary_types )} ):
490
- return cls({ bcast_init_args } )
592
+ if __debug__:
593
+ if isinstance(arg1,
594
+ { tup_str (bcast_actx_ary_types )} ):
595
+ warn("Broadcasting { cls } over array "
596
+ f"context array type {{type(arg1)}} "
597
+ "is deprecated "
598
+ "and will no longer work in 2025."
599
+ "There is no replacement as of right now. "
600
+ "See the discussion in "
601
+ "https://github.com/inducer/arraycontext/"
602
+ "pull/190. ",
603
+ DeprecationWarning, stacklevel=2)
604
+
605
+ return cls({ bcast_init_args_arg2_is_outer } )
606
+
491
607
return NotImplemented
492
608
493
609
cls.__r{ dunder_name } __ = { fname } """ )
0 commit comments