Skip to content

Commit d567ad7

Browse files
committed
outer: disallow non-object numpy arrays
1 parent bb955e3 commit d567ad7

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

Diff for: arraycontext/container/traversal.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -949,8 +949,7 @@ def outer(a: Any, b: Any) -> Any:
949949
Tweaks the behavior of :func:`numpy.outer` to return a lower-dimensional
950950
object if either/both of *a* and *b* are scalars (whereas :func:`numpy.outer`
951951
always returns a matrix). Here the definition of "scalar" includes
952-
all non-array-container types and any scalar-like array container types
953-
(including non-object numpy arrays).
952+
all non-array-container types and any scalar-like array container types.
954953
955954
If *a* and *b* are both array containers, the result will have the same type
956955
as *a*. If both are array containers and neither is an object array, they must
@@ -968,12 +967,19 @@ def treat_as_scalar(x: Any) -> bool:
968967
# This condition is whether "ndarrays should broadcast inside x".
969968
and NumpyObjectArray not in x.__class__._outer_bcast_types)
970969

970+
a_is_ndarray = isinstance(a, np.ndarray)
971+
b_is_ndarray = isinstance(b, np.ndarray)
972+
973+
if a_is_ndarray and a.dtype != object:
974+
raise TypeError("passing a non-object numpy array is not allowed")
975+
if b_is_ndarray and b.dtype != object:
976+
raise TypeError("passing a non-object numpy array is not allowed")
977+
971978
if treat_as_scalar(a) or treat_as_scalar(b):
972979
return a*b
973-
# After this point, "isinstance(o, ndarray)" means o is an object array.
974-
elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
980+
elif a_is_ndarray and b_is_ndarray:
975981
return np.outer(a, b)
976-
elif isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
982+
elif a_is_ndarray or b_is_ndarray:
977983
return map_array_container(lambda x: outer(x, b), a)
978984
else:
979985
if type(a) is not type(b):

Diff for: test/test_arraycontext.py

-9
Original file line numberDiff line numberDiff line change
@@ -1531,15 +1531,6 @@ def equal(a, b):
15311531
b_bcast_dc_of_dofs.momentum),
15321532
enthalpy=a_bcast_dc_of_dofs.enthalpy*b_bcast_dc_of_dofs.enthalpy))
15331533

1534-
# Non-object numpy arrays should be treated as scalars
1535-
ary_of_floats = np.ones(len(b_bcast_dc_of_dofs.mass))
1536-
assert equal(
1537-
outer(ary_of_floats, b_bcast_dc_of_dofs),
1538-
ary_of_floats*b_bcast_dc_of_dofs)
1539-
assert equal(
1540-
outer(a_bcast_dc_of_dofs, ary_of_floats),
1541-
a_bcast_dc_of_dofs*ary_of_floats)
1542-
15431534
# }}}
15441535

15451536

0 commit comments

Comments
 (0)