Skip to content

Commit 8b6dac9

Browse files
committed
Allow broadcasting in cross()
1 parent 397713f commit 8b6dac9

File tree

1 file changed

+0
-3
lines changed

1 file changed

+0
-3
lines changed

array_api_strict/linalg.py

-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,6 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
7373
"""
7474
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
7575
raise TypeError('Only numeric dtypes are allowed in cross')
76-
# Note: this is different from np.cross(), which broadcasts
77-
if x1.shape != x2.shape:
78-
raise ValueError('x1 and x2 must have the same shape')
7976
if x1.ndim == 0:
8077
raise ValueError('cross() requires arrays of dimension at least 1')
8178
# Note: this is different from np.cross(), which allows dimension 2

0 commit comments

Comments
 (0)