Skip to content

Commit 0c721a5

Browse files
committed
compare-shape
1 parent 2ac196c commit 0c721a5

File tree

5 files changed

+19
-6
lines changed

5 files changed

+19
-6
lines changed

dpnp/dpnp_iface_linearalgebra.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1113,13 +1113,13 @@ def vdot(a, b):
11131113
if b.size != 1:
11141114
raise ValueError("The second array should be of size one.")
11151115
a_conj = numpy.conj(a)
1116-
return dpnp.multiply(a_conj, b)
1116+
return dpnp.squeeze(dpnp.multiply(a_conj, b))
11171117

11181118
if dpnp.isscalar(b):
11191119
if a.size != 1:
11201120
raise ValueError("The first array should be of size one.")
11211121
a_conj = dpnp.conj(a)
1122-
return dpnp.multiply(a_conj, b)
1122+
return dpnp.squeeze(dpnp.multiply(a_conj, b))
11231123

11241124
if a.ndim == 1 and b.ndim == 1:
11251125
return dpnp_dot(a, b, out=None, conjugate=True)

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ def dpnp_matmul(
977977
result = dpnp.moveaxis(result, (-2, -1), axes_res)
978978
elif len(axes_res) == 1:
979979
result = dpnp.moveaxis(result, (-1,), axes_res)
980-
return dpnp.ascontiguousarray(result)
980+
return result
981981

982982
return dpnp.asarray(result, order=order)
983983

dpnp/tests/helper.py

+6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def assert_dtype_allclose(
3636
3737
"""
3838

39+
if isinstance(numpy_arr, numpy.ndarray):
40+
assert dpnp_arr.shape == numpy_arr.shape
41+
else:
42+
# numpy output is scalar, then dpnp is 0-D array
43+
assert dpnp_arr.shape == ()
44+
3945
list_64bit_types = [numpy.float64, numpy.complex128]
4046
is_inexact = lambda x: hasattr(x, "dtype") and dpnp.issubdtype(
4147
x.dtype, dpnp.inexact

dpnp/tests/test_mathematical.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -3030,10 +3030,13 @@ def test_matmul_strided1(self, stride):
30303030
expected = numpy.matmul(a, a)
30313031
assert_dtype_allclose(result, expected)
30323032

3033-
OUT = dpnp.empty(shape, dtype=result.dtype)
3033+
OUT = numpy.empty(shape, dtype=result.dtype)
30343034
out = OUT[slices]
3035-
result = dpnp.matmul(a_dp, a_dp, out=out)
3036-
assert result is out
3035+
iOUT = dpnp.array(OUT)
3036+
iout = iOUT[slices]
3037+
result = dpnp.matmul(a_dp, a_dp, out=iout)
3038+
assert result is iout
3039+
expected = numpy.matmul(a, a, out=out)
30373040
assert_dtype_allclose(result, expected)
30383041

30393042
@pytest.mark.parametrize(

dpnp/tests/test_nanfunctions.py

+4
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ def test_allnans(self, dtype, array):
122122

123123
result = getattr(dpnp, self.func)(ia)
124124
expected = getattr(numpy, self.func)(a)
125+
if array.shape == ():
126+
# for "0d" case, dpnp returns 0D array, numpy returns 1D array
127+
# with one element. dpnp result is correct based on Array API
128+
expected = numpy.squeeze(expected)
125129
assert_dtype_allclose(result, expected)
126130

127131
@pytest.mark.parametrize("axis", [None, 0, 1])

0 commit comments

Comments
 (0)