Skip to content

Commit ce8ab7c

Browse files
kaushikcfdinducer
authored andcommitted
arithmetic fixes to account for np.ndarray being a leaf array
1 parent 6e30532 commit ce8ab7c

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

arraycontext/container/arithmetic.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -492,16 +492,17 @@ def {fname}(arg1):
492492
bcast_actx_ary_types = ()
493493

494494
gen(f"""
495-
if {bool(outer_bcast_type_names)}: # optimized away
496-
if isinstance(arg2,
497-
{tup_str(outer_bcast_type_names
498-
+ bcast_actx_ary_types)}):
499-
return cls({bcast_same_cls_init_args})
500495
if {numpy_pred("arg2")}:
501496
result = np.empty_like(arg2, dtype=object)
502497
for i in np.ndindex(arg2.shape):
503498
result[i] = {op_str.format("arg1", "arg2[i]")}
504499
return result
500+
501+
if {bool(outer_bcast_type_names)}: # optimized away
502+
if isinstance(arg2,
503+
{tup_str(outer_bcast_type_names
504+
+ bcast_actx_ary_types)}):
505+
return cls({bcast_same_cls_init_args})
505506
return NotImplemented
506507
""")
507508
gen(f"cls.__{dunder_name}__ = {fname}")
@@ -538,16 +539,16 @@ def {fname}(arg1):
538539
def {fname}(arg2, arg1):
539540
# assert other.__cls__ is not cls
540541
541-
if {bool(outer_bcast_type_names)}: # optimized away
542-
if isinstance(arg1,
543-
{tup_str(outer_bcast_type_names
544-
+ bcast_actx_ary_types)}):
545-
return cls({bcast_init_args})
546542
if {numpy_pred("arg1")}:
547543
result = np.empty_like(arg1, dtype=object)
548544
for i in np.ndindex(arg1.shape):
549545
result[i] = {op_str.format("arg1[i]", "arg2")}
550546
return result
547+
if {bool(outer_bcast_type_names)}: # optimized away
548+
if isinstance(arg1,
549+
{tup_str(outer_bcast_type_names
550+
+ bcast_actx_ary_types)}):
551+
return cls({bcast_init_args})
551552
return NotImplemented
552553
553554
cls.__r{dunder_name}__ = {fname}""")

0 commit comments

Comments
 (0)