Skip to content

Commit 510dc1b

Browse files
committed
with_container_arithmetic: Rename arguments to signal who broadcasts across who
Names suggested by @majosm
1 parent 8b1b795 commit 510dc1b

File tree

2 files changed

+119
-54
lines changed

2 files changed

+119
-54
lines changed

arraycontext/container/arithmetic.py

+113-48
Original file line numberDiff line numberDiff line change
@@ -159,34 +159,40 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet
159159

160160

161161
def with_container_arithmetic(
162-
*,
163-
bcast_number: bool = True,
164-
_bcast_actx_array_type: Optional[bool] = None,
165-
bcast_obj_array: Optional[bool] = None,
166-
bcast_numpy_array: bool = False,
167-
bcast_container_types: Optional[Tuple[type, ...]] = None,
168-
arithmetic: bool = True,
169-
matmul: bool = False,
170-
bitwise: bool = False,
171-
shift: bool = False,
172-
_cls_has_array_context_attr: Optional[bool] = None,
173-
eq_comparison: Optional[bool] = None,
174-
rel_comparison: Optional[bool] = None) -> Callable[[type], type]:
162+
*,
163+
number_bcasts_across: Optional[bool] = None,
164+
bcasts_across_obj_array: Optional[bool] = None,
165+
container_types_bcast_across: Optional[Tuple[type, ...]] = None,
166+
arithmetic: bool = True,
167+
matmul: bool = False,
168+
bitwise: bool = False,
169+
shift: bool = False,
170+
_cls_has_array_context_attr: Optional[bool] = None,
171+
eq_comparison: Optional[bool] = None,
172+
rel_comparison: Optional[bool] = None,
173+
174+
# deprecated:
175+
bcast_number: Optional[bool] = None,
176+
bcast_obj_array: Optional[bool] = None,
177+
bcast_numpy_array: bool = False,
178+
_bcast_actx_array_type: Optional[bool] = None,
179+
bcast_container_types: Optional[Tuple[type, ...]] = None,
180+
) -> Callable[[type], type]:
175181
"""A class decorator that implements built-in operators for array containers
176182
by propagating the operations to the elements of the container.
177183
178-
:arg bcast_number: If *True*, numbers broadcast over the container
184+
:arg number_bcasts_across: If *True*, numbers broadcast over the container
179185
(with the container as the 'outer' structure).
180-
:arg bcast_obj_array: If *True*, this container will be broadcast
186+
:arg bcasts_across_obj_array: If *True*, this container will be broadcast
181187
across :mod:`numpy` object arrays
182188
(with the object array as the 'outer' structure).
183-
Add :class:`numpy.ndarray` to *bcast_container_types* to achieve
189+
Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve
184190
the 'reverse' broadcasting.
185-
:arg bcast_container_types: A sequence of container types that will broadcast
191+
:arg container_types_bcast_across: A sequence of container types that will broadcast
186192
across this container, with this container as the 'outer' structure.
187193
:class:`numpy.ndarray` is permitted to be part of this sequence to
188-
indicate that object arrays (and *only* object arrays) will be broadcasat.
189-
In this case, *bcast_obj_array* must be *False*.
194+
indicate that object arrays (and *only* object arrays) will be broadcast.
195+
In this case, *bcasts_across_obj_array* must be *False*.
190196
:arg arithmetic: Implement the conventional arithmetic operators, including
191197
``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
192198
:func:`abs`.
@@ -241,8 +247,71 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
241247

242248
# {{{ handle inputs
243249

244-
if bcast_obj_array is None:
245-
raise TypeError("bcast_obj_array must be specified")
250+
if rel_comparison and eq_comparison is None:
251+
eq_comparison = True
252+
253+
if eq_comparison is None:
254+
raise TypeError("eq_comparison must be specified")
255+
256+
# {{{ handle bcast_number
257+
258+
if bcast_number is not None:
259+
if number_bcasts_across is not None:
260+
raise TypeError(
261+
"may specify at most one of 'bcast_number' and "
262+
"'number_bcasts_across'")
263+
264+
warn("'bcast_number' is deprecated and will be unsupported from 2025. "
265+
"Use 'number_bcasts_across', with equivalent meaning.",
266+
DeprecationWarning, stacklevel=2)
267+
number_bcasts_across = bcast_number
268+
else:
269+
if number_bcasts_across is None:
270+
number_bcasts_across = True
271+
272+
del bcast_number
273+
274+
# }}}
275+
276+
# {{{ handle bcast_obj_array
277+
278+
if bcast_obj_array is not None:
279+
if bcasts_across_obj_array is not None:
280+
raise TypeError(
281+
"may specify at most one of 'bcast_obj_array' and "
282+
"'bcasts_across_obj_array'")
283+
284+
warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. "
285+
"Use 'bcasts_across_obj_array', with equivalent meaning.",
286+
DeprecationWarning, stacklevel=2)
287+
bcasts_across_obj_array = bcast_obj_array
288+
else:
289+
if bcasts_across_obj_array is None:
290+
raise TypeError("bcasts_across_obj_array must be specified")
291+
292+
del bcast_obj_array
293+
294+
# }}}
295+
296+
# {{{ handle bcast_container_types
297+
298+
if bcast_container_types is not None:
299+
if container_types_bcast_across is not None:
300+
raise TypeError(
301+
"may specify at most one of 'bcast_container_types' and "
302+
"'container_types_bcast_across'")
303+
304+
warn("'bcast_container_types' is deprecated and will be unsupported from 2025. "
305+
"Use 'container_types_bcast_across', with equivalent meaning.",
306+
DeprecationWarning, stacklevel=2)
307+
container_types_bcast_across = bcast_container_types
308+
else:
309+
if container_types_bcast_across is None:
310+
container_types_bcast_across = ()
311+
312+
del bcast_container_types
313+
314+
# }}}
246315

247316
if rel_comparison is None:
248317
raise TypeError("rel_comparison must be specified")
@@ -255,36 +324,27 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
255324
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
256325
" cannot be both set.")
257326

258-
if rel_comparison and eq_comparison is None:
259-
eq_comparison = True
260-
261-
if eq_comparison is None:
262-
raise TypeError("eq_comparison must be specified")
263-
264-
if not bcast_obj_array and bcast_numpy_array:
327+
if not bcasts_across_obj_array and bcast_numpy_array:
265328
raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")
266329

267330
if bcast_numpy_array:
268331
def numpy_pred(name: str) -> str:
269332
return f"is_numpy_array({name})"
270-
elif bcast_obj_array:
333+
elif bcasts_across_obj_array:
271334
def numpy_pred(name: str) -> str:
272335
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
273336
else:
274337
def numpy_pred(name: str) -> str:
275338
return "False" # optimized away
276339

277-
if bcast_container_types is None:
278-
bcast_container_types = ()
279-
280-
if np.ndarray in bcast_container_types and bcast_obj_array:
340+
if np.ndarray in container_types_bcast_across and bcasts_across_obj_array:
281341
raise ValueError("If numpy.ndarray is part of bcast_container_types, "
282342
"bcast_obj_array must be False.")
283343

284344
numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray]
285-
bcast_container_types = tuple(
345+
container_types_bcast_across = tuple(
286346
new_ct
287-
for old_ct in bcast_container_types
347+
for old_ct in container_types_bcast_across
288348
for new_ct in
289349
(numpy_check_types
290350
if old_ct is np.ndarray
@@ -334,7 +394,7 @@ def wrap(cls: Any) -> Any:
334394

335395
if bcast_actx_array_type is None:
336396
if cls_has_array_context_attr:
337-
if bcast_number:
397+
if number_bcasts_across:
338398
bcast_actx_array_type = cls_has_array_context_attr
339399
else:
340400
bcast_actx_array_type = False
@@ -409,14 +469,14 @@ def is_numpy_array(arg):
409469
""")
410470
gen("")
411471

412-
if bcast_container_types:
413-
for i, bct in enumerate(bcast_container_types):
472+
if container_types_bcast_across:
473+
for i, bct in enumerate(container_types_bcast_across):
414474
gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
415475
gen("")
416-
outer_bcast_type_names = tuple(
417-
f"_bctype{i}" for i in range(len(bcast_container_types)))
418-
if bcast_number:
419-
outer_bcast_type_names += ("Number",)
476+
container_type_names_bcast_across = tuple(
477+
f"_bctype{i}" for i in range(len(container_types_bcast_across)))
478+
if number_bcasts_across:
479+
container_type_names_bcast_across += ("Number",)
420480

421481
def same_key(k1: T, k2: T) -> T:
422482
assert k1 == k2
@@ -428,9 +488,14 @@ def tup_str(t: Tuple[str, ...]) -> str:
428488
else:
429489
return "({},)".format(", ".join(t))
430490

431-
gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}")
491+
gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}")
492+
gen("cls._container_types_bcast_across = "
493+
f"{tup_str(container_type_names_bcast_across)}")
494+
432495
gen(f"cls._bcast_numpy_array = {bcast_numpy_array}")
433-
gen(f"cls._bcast_obj_array = {bcast_obj_array}")
496+
497+
gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}")
498+
gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}")
434499
gen("")
435500

436501
# {{{ unary operators
@@ -535,9 +600,9 @@ def {fname}(arg1):
535600
result[i] = {op_str.format("arg1", "arg2[i]")}
536601
return result
537602
538-
if {bool(outer_bcast_type_names)}: # optimized away
603+
if {bool(container_type_names_bcast_across)}: # optimized away
539604
if isinstance(arg2,
540-
{tup_str(outer_bcast_type_names
605+
{tup_str(container_type_names_bcast_across
541606
+ bcast_actx_ary_types)}):
542607
if __debug__:
543608
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
@@ -584,9 +649,9 @@ def {fname}(arg2, arg1):
584649
for i in np.ndindex(arg1.shape):
585650
result[i] = {op_str.format("arg1[i]", "arg2")}
586651
return result
587-
if {bool(outer_bcast_type_names)}: # optimized away
652+
if {bool(container_type_names_bcast_across)}: # optimized away
588653
if isinstance(arg1,
589-
{tup_str(outer_bcast_type_names
654+
{tup_str(container_type_names_bcast_across
590655
+ bcast_actx_ary_types)}):
591656
if __debug__:
592657
if isinstance(arg1,

test/test_arraycontext.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _acf():
117117
# {{{ stand-in DOFArray implementation
118118

119119
@with_container_arithmetic(
120-
bcast_obj_array=True,
120+
bcasts_across_obj_array=True,
121121
bitwise=True,
122122
rel_comparison=True,
123123
_cls_has_array_context_attr=True,
@@ -208,7 +208,7 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type:
208208

209209
# {{{ nested containers
210210

211-
@with_container_arithmetic(bcast_obj_array=False,
211+
@with_container_arithmetic(bcasts_across_obj_array=False,
212212
eq_comparison=False, rel_comparison=False,
213213
_cls_has_array_context_attr=True,
214214
_bcast_actx_array_type=False)
@@ -231,7 +231,7 @@ def array_context(self):
231231

232232

233233
@with_container_arithmetic(
234-
bcast_obj_array=False,
234+
bcasts_across_obj_array=False,
235235
bcast_container_types=(DOFArray, np.ndarray),
236236
matmul=True,
237237
rel_comparison=True,
@@ -1225,7 +1225,7 @@ def test_norm_ord_none(actx_factory, ndim):
12251225

12261226
# {{{ test_actx_compile helpers
12271227

1228-
@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
1228+
@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
12291229
@dataclass_array_container
12301230
@dataclass(frozen=True)
12311231
class Velocity2D:
@@ -1355,7 +1355,7 @@ def test_container_equality(actx_factory):
13551355
# {{{ test_no_leaf_array_type_broadcasting
13561356

13571357
@with_container_arithmetic(
1358-
bcast_obj_array=True,
1358+
bcasts_across_obj_array=True,
13591359
rel_comparison=True,
13601360
_cls_has_array_context_attr=True,
13611361
_bcast_actx_array_type=False)
@@ -1459,7 +1459,7 @@ def equal(a, b):
14591459

14601460
# {{{ test_array_container_with_numpy
14611461

1462-
@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
1462+
@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
14631463
@dataclass_array_container
14641464
@dataclass(frozen=True)
14651465
class ArrayContainerWithNumpy:

0 commit comments

Comments
 (0)