@@ -198,13 +198,14 @@ def _define_dim_flags(x, axis):
198
198
"""
199
199
Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
200
200
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
201
- except for one of them ), for instance, if x.shape = (1, 1, 1, 2),
202
- then x_is_1D = True
201
+ except for dimension at `axis` ), for instance, if x.shape = (1, 1, 1, 2),
202
+ and axis=-1, then x_is_1D = True.
203
203
x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one
204
204
except for the last two of them), for instance, if x.shape = (1, 1, 3, 2),
205
- then x_is_2D = True
205
+ then x_is_2D = True.
206
206
x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance,
207
- if x.shape = (3, 4, 1, 2), then x_base_is_1D = True
207
+ if x.shape = (3, 4, 1, 2), then x_base_is_1D = True.
208
+
208
209
"""
209
210
210
211
x_shape = x .shape
@@ -326,14 +327,11 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):
326
327
if x1_shape [- 1 ] != x2_shape [- 1 ]:
327
328
_shape_error (x1_shape [- 1 ], x2_shape [- 1 ], "vecdot" , err_msg = 0 )
328
329
329
- _ , x1_is_1D , _ = _define_dim_flags (x1 , axis = - 1 )
330
- _ , x2_is_1D , _ = _define_dim_flags (x2 , axis = - 1 )
331
-
332
330
if x1_ndim == 1 and x2_ndim == 1 :
333
331
result_shape = ()
334
- elif x1_is_1D :
332
+ elif x1_ndim == 1 :
335
333
result_shape = x2_shape [:- 1 ]
336
- elif x2_is_1D :
334
+ elif x2_ndim == 1 :
337
335
result_shape = x1_shape [:- 1 ]
338
336
else : # at least 2D
339
337
if x1_ndim != x2_ndim :
0 commit comments