Skip to content

Commit 29d948b

Browse files
authored
Merge pull request #96 from asmeurer/vector_norm-fix
Fix numpy vector_norm(keepdims=True)
2 parents b28a0ea + 0837875 commit 29d948b

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

array_api_compat/common/_linalg.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Literal, Optional, Tuple, Union
66
from ._typing import ndarray
77

8+
import math
9+
810
import numpy as np
911
if np.__version__[0] == "2":
1012
from numpy.lib.array_utils import normalize_axis_tuple
@@ -110,21 +112,22 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]
110112
# on a single dimension.
111113
if axis is None:
112114
# Note: xp.linalg.norm() doesn't handle 0-D arrays
113-
x = x.ravel()
115+
_x = x.ravel()
114116
_axis = 0
115117
elif isinstance(axis, tuple):
116118
# Note: The axis argument supports any number of axes, whereas
117119
# xp.linalg.norm() only supports a single axis for vector norm.
118120
normalized_axis = normalize_axis_tuple(axis, x.ndim)
119121
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
120122
newshape = axis + rest
121-
x = xp.transpose(x, newshape).reshape(
122-
(xp.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest]))
123+
_x = xp.transpose(x, newshape).reshape(
124+
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
123125
_axis = 0
124126
else:
127+
_x = x
125128
_axis = axis
126129

127-
res = xp.linalg.norm(x, axis=_axis, ord=ord)
130+
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
128131

129132
if keepdims:
130133
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks

test_cupy.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ mkdir -p $SCRIPT_DIR/.hypothesis
2626
ln -s $SCRIPT_DIR/.hypothesis .hypothesis
2727

2828
export ARRAY_API_TESTS_MODULE=array_api_compat.cupy
29-
pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt --skips-file $SCRIPT_DIR/cupy-skips.txt "$@"
29+
pytest ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@"

0 commit comments

Comments
 (0)