Skip to content

Commit 5660da3

Browse files
authored
fix issue gh-2293 (#2294)
resolves issue #2293
1 parent da2eeba commit 5660da3

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

.github/workflows/array-api-skips.txt

-4
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,5 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
3434
array_api_tests/test_signatures.py::test_func_signature[std]
3535
array_api_tests/test_signatures.py::test_func_signature[var]
3636

37-
# wrong shape is returned
38-
array_api_tests/test_linalg.py::test_vecdot
39-
array_api_tests/test_linalg.py::test_linalg_vecdot
40-
4137
# arrays have different values
4238
array_api_tests/test_linalg.py::test_linalg_tensordot

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,14 @@ def _define_dim_flags(x, axis):
198198
"""
199199
Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
200200
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
201-
except for one of them), for instance, if x.shape = (1, 1, 1, 2),
202-
then x_is_1D = True
201+
except for dimension at `axis`), for instance, if x.shape = (1, 1, 1, 2),
202+
and axis=-1, then x_is_1D = True.
203203
x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one
204204
except for the last two of them), for instance, if x.shape = (1, 1, 3, 2),
205-
then x_is_2D = True
205+
then x_is_2D = True.
206206
x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance,
207-
if x.shape = (3, 4, 1, 2), then x_base_is_1D = True
207+
if x.shape = (3, 4, 1, 2), then x_base_is_1D = True.
208+
208209
"""
209210

210211
x_shape = x.shape
@@ -326,14 +327,11 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):
326327
if x1_shape[-1] != x2_shape[-1]:
327328
_shape_error(x1_shape[-1], x2_shape[-1], "vecdot", err_msg=0)
328329

329-
_, x1_is_1D, _ = _define_dim_flags(x1, axis=-1)
330-
_, x2_is_1D, _ = _define_dim_flags(x2, axis=-1)
331-
332330
if x1_ndim == 1 and x2_ndim == 1:
333331
result_shape = ()
334-
elif x1_is_1D:
332+
elif x1_ndim == 1:
335333
result_shape = x2_shape[:-1]
336-
elif x2_is_1D:
334+
elif x2_ndim == 1:
337335
result_shape = x1_shape[:-1]
338336
else: # at least 2D
339337
if x1_ndim != x2_ndim:

dpnp/tests/test_product.py

+2
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,8 @@ def setup_method(self):
10001000
((1, 4, 5), (3, 1, 5)),
10011001
((1, 1, 4, 5), (3, 1, 5)),
10021002
((1, 4, 5), (1, 3, 1, 5)),
1003+
((2, 1), (1, 1, 1)),
1004+
((1, 1, 3), (3,)),
10031005
],
10041006
)
10051007
def test_basic(self, dtype, shape1, shape2):

0 commit comments

Comments
 (0)