22from  __future__ import  annotations 
33
44
5- """ 
5+ __doc__   =   """ 
66.. currentmodule:: arraycontext 
7+ 
78.. autofunction:: with_container_arithmetic 
89""" 
910
10- import  enum 
11- 
1211
1312__copyright__  =  """ 
1413Copyright (C) 2020-1 University of Illinois Board of Trustees 
3433THE SOFTWARE. 
3534""" 
3635
36+ import  enum 
3737from  typing  import  Any , Callable , Optional , Tuple , TypeVar , Union 
38+ from  warnings  import  warn 
3839
3940import  numpy  as  np 
4041
@@ -99,8 +100,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str:
99100
100101
101102def  _format_binary_op_str (op_str : str ,
102-         arg1 : Union [Tuple [str , ... ], str ],
103-         arg2 : Union [Tuple [str , ... ], str ]) ->  str :
103+         arg1 : Union [Tuple [str , str ], str ],
104+         arg2 : Union [Tuple [str , str ], str ]) ->  str :
104105    if  isinstance (arg1 , tuple ) and  isinstance (arg2 , tuple ):
105106        import  sys 
106107        if  sys .version_info  >=  (3 , 10 ):
@@ -127,6 +128,36 @@ def _format_binary_op_str(op_str: str,
127128        return  op_str .format (arg1 , arg2 )
128129
129130
131+ class  NumpyObjectArrayMetaclass (type ):
132+     def  __instancecheck__ (cls , instance : Any ) ->  bool :
133+         return  isinstance (instance , np .ndarray ) and  instance .dtype  ==  object 
134+ 
135+ 
136+ class  NumpyObjectArray (metaclass = NumpyObjectArrayMetaclass ):
137+     pass 
138+ 
139+ 
140+ class  ComplainingNumpyNonObjectArrayMetaclass (type ):
141+     def  __instancecheck__ (cls , instance : Any ) ->  bool :
142+         if  isinstance (instance , np .ndarray ) and  instance .dtype  !=  object :
143+             # Example usage site: 
144+             # https://github.com/illinois-ceesd/mirgecom/blob/f5d0d97c41e8c8a05546b1d1a6a2979ec8ea3554/mirgecom/inviscid.py#L148-L149 
145+             # where normal is passed in by test_lfr_flux as a 'custom-made' 
146+             # numpy array of dtype float64. 
147+             warn (
148+                  "Broadcasting container against non-object numpy array. " 
149+                  "This was never documented to work and will now stop working in " 
150+                  "2025. Convert the array to an object array to preserve the " 
151+                  "current semantics." , DeprecationWarning , stacklevel = 3 )
152+             return  True 
153+         else :
154+             return  False 
155+ 
156+ 
157+ class  ComplainingNumpyNonObjectArray (metaclass = ComplainingNumpyNonObjectArrayMetaclass ):
158+     pass 
159+ 
160+ 
130161def  with_container_arithmetic (
131162        * ,
132163        bcast_number : bool  =  True ,
@@ -146,22 +177,16 @@ def with_container_arithmetic(
146177
147178    :arg bcast_number: If *True*, numbers broadcast over the container 
148179        (with the container as the 'outer' structure). 
149-     :arg _bcast_actx_array_type: If *True*, instances of base array types of the 
150-         container's array context are broadcasted over the container. Can be 
151-         *True* only if the container has *_cls_has_array_context_attr* set. 
152-         Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, 
153-         else *False*. 
154-     :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over 
155-         the container.  (with the container as the 'inner' structure) 
156-     :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast 
157-         over the container.  (with the container as the 'inner' structure) 
158-         If this is set to *True*, *bcast_obj_array* must also be *True*. 
180+     :arg bcast_obj_array: If *True*, this container will be broadcast 
181+         across :mod:`numpy` object arrays 
182+         (with the object array as the 'outer' structure). 
183+         Add :class:`numpy.ndarray` to *bcast_container_types* to achieve 
184+         the 'reverse' broadcasting. 
159185    :arg bcast_container_types: A sequence of container types that will broadcast 
160-         over  this container ( with this container as the 'outer' structure) . 
186+         across  this container,  with this container as the 'outer' structure. 
161187        :class:`numpy.ndarray` is permitted to be part of this sequence to 
162-         indicate that, in such broadcasting situations, this container should 
163-         be the 'outer' structure. In this case, *bcast_obj_array* 
164-         (and consequently *bcast_numpy_array*) must be *False*. 
188+         indicate that object arrays (and *only* object arrays) will be broadcasat. 
189+         In this case, *bcast_obj_array* must be *False*. 
165190    :arg arithmetic: Implement the conventional arithmetic operators, including 
166191        ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as 
167192        :func:`abs`. 
@@ -203,6 +228,17 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
203228    should nest "outside" :func:dataclass_array_container`. 
204229    """ 
205230
231+     # Hard-won design lessons: 
232+     # 
233+     # - Anything that special-cases np.ndarray by type is broken by design because: 
234+     #   - np.ndarray is an array context array. 
235+     #   - numpy object arrays can be array containers. 
236+     #   Using NumpyObjectArray and NumpyNonObjectArray *may* be better? 
237+     #   They're new, so there is no operational experience with them. 
238+     # 
239+     # - Broadcast rules are hard to change once established, particularly 
240+     #   because one cannot grep for their use. 
241+ 
206242    # {{{ handle inputs 
207243
208244    if  bcast_obj_array  is  None :
@@ -212,9 +248,8 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
212248        raise  TypeError ("rel_comparison must be specified" )
213249
214250    if  bcast_numpy_array :
215-         from  warnings  import  warn 
216251        warn ("'bcast_numpy_array=True' is deprecated and will be unsupported" 
217-              " from December 2021 " , DeprecationWarning , stacklevel = 2 )
252+              " from 2025. " , DeprecationWarning , stacklevel = 2 )
218253
219254        if  _bcast_actx_array_type :
220255            raise  ValueError ("'bcast_numpy_array' and '_bcast_actx_array_type'" 
@@ -231,7 +266,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
231266
232267    if  bcast_numpy_array :
233268        def  numpy_pred (name : str ) ->  str :
234-             return  f"isinstance ({ name } , np.ndarray )" 
269+             return  f"is_numpy_array ({ name }  )" 
235270    elif  bcast_obj_array :
236271        def  numpy_pred (name : str ) ->  str :
237272            return  f"isinstance({ name }  , np.ndarray) and { name }  .dtype.char == 'O'" 
@@ -241,12 +276,21 @@ def numpy_pred(name: str) -> str:
241276
242277    if  bcast_container_types  is  None :
243278        bcast_container_types  =  ()
244-     bcast_container_types_count  =  len (bcast_container_types )
245279
246280    if  np .ndarray  in  bcast_container_types  and  bcast_obj_array :
247281        raise  ValueError ("If numpy.ndarray is part of bcast_container_types, " 
248282                "bcast_obj_array must be False." )
249283
284+     numpy_check_types : list [type ] =  [NumpyObjectArray , ComplainingNumpyNonObjectArray ]
285+     bcast_container_types  =  tuple (
286+         new_ct 
287+         for  old_ct  in  bcast_container_types 
288+         for  new_ct  in 
289+         (numpy_check_types 
290+         if  old_ct  is  np .ndarray 
291+         else  [old_ct ])
292+     )
293+ 
250294    desired_op_classes  =  set ()
251295    if  arithmetic :
252296        desired_op_classes .add (_OpClass .ARITHMETIC )
@@ -264,19 +308,24 @@ def numpy_pred(name: str) -> str:
264308    # }}} 
265309
266310    def  wrap (cls : Any ) ->  Any :
267-         cls_has_array_context_attr : bool  |  None  =  \
268-                 _cls_has_array_context_attr 
269-         bcast_actx_array_type : bool  |  None  =  \
270-                 _bcast_actx_array_type 
311+         if  not  hasattr (cls , "__array_ufunc__" ):
312+             warn (f"{ cls }   does not have __array_ufunc__ set. " 
313+                  "This will cause numpy to attempt broadcasting, in a way that " 
314+                  "is likely undesired. " 
315+                  f"To avoid this, set __array_ufunc__ = None in { cls }  ." ,
316+                  stacklevel = 2 )
317+ 
318+         cls_has_array_context_attr : bool  |  None  =  _cls_has_array_context_attr 
319+         bcast_actx_array_type : bool  |  None  =  _bcast_actx_array_type 
271320
272321        if  cls_has_array_context_attr  is  None :
273322            if  hasattr (cls , "array_context" ):
274323                raise  TypeError (
275324                        f"{ cls }   has an 'array_context' attribute, but it does not " 
276325                        "set '_cls_has_array_context_attr' to *True* when calling " 
277326                        "with_container_arithmetic. This is being interpreted " 
278-                         "as 'array_context' being permitted to fail with an exception,  " 
279-                         "which is no longer allowed. " 
327+                         "as '. array_context' being permitted to fail " 
328+                         "with an exception,  which is no longer allowed. " 
280329                        f"If { cls .__name__ }  .array_context will not fail, pass " 
281330                        "'_cls_has_array_context_attr=True'. " 
282331                        "If you do not want container arithmetic to make " 
@@ -294,6 +343,30 @@ def wrap(cls: Any) -> Any:
294343                raise  TypeError ("_bcast_actx_array_type can be True only if " 
295344                                "_cls_has_array_context_attr is set." )
296345
346+         if  bcast_actx_array_type :
347+             if  _bcast_actx_array_type :
348+                 warn (
349+                     f"Broadcasting array context array types across { cls }   " 
350+                     "has been explicitly " 
351+                     "enabled. As of 2025, this will stop working. " 
352+                     "There is no replacement as of right now. " 
353+                     "See the discussion in " 
354+                     "https://github.com/inducer/arraycontext/pull/190. " 
355+                     "To opt out now (and avoid this warning), " 
356+                     "pass _bcast_actx_array_type=False. " ,
357+                     DeprecationWarning , stacklevel = 2 )
358+             else :
359+                 warn (
360+                     f"Broadcasting array context array types across { cls }   " 
361+                     "has been implicitly " 
362+                     "enabled. As of 2025, this will no longer work. " 
363+                     "There is no replacement as of right now. " 
364+                     "See the discussion in " 
365+                     "https://github.com/inducer/arraycontext/pull/190. " 
366+                     "To opt out now (and avoid this warning), " 
367+                     "pass _bcast_actx_array_type=False." ,
368+                     DeprecationWarning , stacklevel = 2 )
369+ 
297370        if  (not  hasattr (cls , "_serialize_init_arrays_code" )
298371                or  not  hasattr (cls , "_deserialize_init_arrays_code" )):
299372            raise  TypeError (f"class '{ cls .__name__ }  ' must provide serialization " 
@@ -304,7 +377,7 @@ def wrap(cls: Any) -> Any:
304377
305378        from  pytools .codegen  import  CodeGenerator , Indentation 
306379        gen  =  CodeGenerator ()
307-         gen (""" 
380+         gen (f """
308381            from numbers import Number 
309382            import numpy as np 
310383            from arraycontext import ArrayContainer 
@@ -315,6 +388,24 @@ def _raise_if_actx_none(actx):
315388                    raise ValueError("array containers with frozen arrays " 
316389                        "cannot be operated upon") 
317390                return actx 
391+ 
392+             def is_numpy_array(arg): 
393+                 if isinstance(arg, np.ndarray): 
394+                     if arg.dtype != "O": 
395+                         warn("Operand is a non-object numpy array, " 
396+                             "and the broadcasting behavior of this array container " 
397+                             "({ cls }  ) " 
398+                             "is influenced by this because of its use of " 
399+                             "the deprecated bcast_numpy_array. This broadcasting " 
400+                             "behavior will change in 2025. If you would like the " 
401+                             "broadcasting behavior to stay the same, make sure " 
402+                             "to convert the passed numpy array to an " 
403+                             "object array.", 
404+                             DeprecationWarning, stacklevel=3) 
405+                     return True 
406+                 else: 
407+                     return False 
408+ 
318409            """ )
319410        gen ("" )
320411
@@ -323,7 +414,7 @@ def _raise_if_actx_none(actx):
323414                gen (f"from { bct .__module__ }   import { bct .__qualname__ }   as _bctype{ i }  " )
324415            gen ("" )
325416        outer_bcast_type_names  =  tuple ([
326-                 f"_bctype{ i }  "  for  i  in  range (bcast_container_types_count )
417+                 f"_bctype{ i }  "  for  i  in  range (len ( bcast_container_types ) )
327418                ])
328419        if  bcast_number :
329420            outer_bcast_type_names  +=  ("Number" ,)
@@ -384,20 +475,25 @@ def {fname}(arg1):
384475
385476                continue 
386477
387-             # {{{ "forward" binary operators 
388- 
389478            zip_init_args  =  cls ._deserialize_init_arrays_code ("arg1" , {
390479                    same_key (key_arg1 , key_arg2 ):
391480                    _format_binary_op_str (op_str , expr_arg1 , expr_arg2 )
392481                    for  (key_arg1 , expr_arg1 ), (key_arg2 , expr_arg2 ) in  zip (
393482                        cls ._serialize_init_arrays_code ("arg1" ).items (),
394483                        cls ._serialize_init_arrays_code ("arg2" ).items ())
395484                    })
396-             bcast_same_cls_init_args  =  cls ._deserialize_init_arrays_code ("arg1" , {
485+             bcast_init_args_arg1_is_outer  =  cls ._deserialize_init_arrays_code ("arg1" , {
397486                    key_arg1 : _format_binary_op_str (op_str , expr_arg1 , "arg2" )
398487                    for  key_arg1 , expr_arg1  in 
399488                    cls ._serialize_init_arrays_code ("arg1" ).items ()
400489                    })
490+             bcast_init_args_arg2_is_outer  =  cls ._deserialize_init_arrays_code ("arg2" , {
491+                     key_arg2 : _format_binary_op_str (op_str , "arg1" , expr_arg2 )
492+                     for  key_arg2 , expr_arg2  in 
493+                     cls ._serialize_init_arrays_code ("arg2" ).items ()
494+                     })
495+ 
496+             # {{{ "forward" binary operators 
401497
402498            gen (f"def { fname }  (arg1, arg2):" )
403499            with  Indentation (gen ):
@@ -424,7 +520,7 @@ def {fname}(arg1):
424520
425521                if  bcast_actx_array_type :
426522                    if  __debug__ :
427-                         bcast_actx_ary_types  =  (
523+                         bcast_actx_ary_types :  tuple [ str , ...]  =  (
428524                            "*_raise_if_actx_none(" 
429525                            "arg1.array_context).array_types" ,)
430526                    else :
@@ -444,7 +540,19 @@ def {fname}(arg1):
444540                    if isinstance(arg2, 
445541                                  { tup_str (outer_bcast_type_names  
446542                                           +  bcast_actx_ary_types )}  ):
447-                         return cls({ bcast_same_cls_init_args }  ) 
543+                         if __debug__: 
544+                             if isinstance(arg2, { tup_str (bcast_actx_ary_types )}  ): 
545+                                 warn("Broadcasting { cls }   over array " 
546+                                     f"context array type {{type(arg2)}} is deprecated " 
547+                                     "and will no longer work in 2025. " 
548+                                     "There is no replacement as of right now. " 
549+                                     "See the discussion in " 
550+                                     "https://github.com/inducer/arraycontext/" 
551+                                     "pull/190. ", 
552+                                     DeprecationWarning, stacklevel=2) 
553+ 
554+                         return cls({ bcast_init_args_arg1_is_outer }  ) 
555+ 
448556                return NotImplemented 
449557                """ )
450558            gen (f"cls.__{ dunder_name }  __ = { fname }  " )
@@ -456,12 +564,6 @@ def {fname}(arg1):
456564
457565            if  reversible :
458566                fname  =  f"_{ cls .__name__ .lower ()}  _r{ dunder_name }  " 
459-                 bcast_init_args  =  cls ._deserialize_init_arrays_code ("arg2" , {
460-                         key_arg2 : _format_binary_op_str (
461-                             op_str , "arg1" , expr_arg2 )
462-                         for  key_arg2 , expr_arg2  in 
463-                         cls ._serialize_init_arrays_code ("arg2" ).items ()
464-                         })
465567
466568                if  bcast_actx_array_type :
467569                    if  __debug__ :
@@ -487,7 +589,21 @@ def {fname}(arg2, arg1):
487589                            if isinstance(arg1, 
488590                                          { tup_str (outer_bcast_type_names  
489591                                                   +  bcast_actx_ary_types )}  ):
490-                                 return cls({ bcast_init_args }  ) 
592+                                 if __debug__: 
593+                                     if isinstance(arg1, 
594+                                             { tup_str (bcast_actx_ary_types )}  ): 
595+                                         warn("Broadcasting { cls }   over array " 
596+                                             f"context array type {{type(arg1)}} " 
597+                                             "is deprecated " 
598+                                             "and will no longer work in 2025." 
599+                                             "There is no replacement as of right now. " 
600+                                             "See the discussion in " 
601+                                             "https://github.com/inducer/arraycontext/" 
602+                                             "pull/190. ", 
603+                                             DeprecationWarning, stacklevel=2) 
604+ 
605+                                 return cls({ bcast_init_args_arg2_is_outer }  ) 
606+ 
491607                        return NotImplemented 
492608
493609                    cls.__r{ dunder_name }  __ = { fname }  """ )
0 commit comments