forked from data-apis/array-api-strict
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_statistical_functions.py
190 lines (161 loc) · 5.52 KB
/
_statistical_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from typing import Any
import numpy as np
from ._array_object import Array
from ._creation_functions import ones, zeros
from ._dtypes import (
DType,
_floating_dtypes,
_np_dtype,
_numeric_dtypes,
_real_floating_dtypes,
_real_numeric_dtypes,
complex64,
float32,
)
from ._flags import get_array_api_strict_flags, requires_api_version
from ._manipulation_functions import concat
@requires_api_version('2023.12')
def cumulative_sum(
    x: Array,
    /,
    *,
    axis: int | None = None,
    dtype: DType | None = None,
    include_initial: bool = False,
) -> Array:
    """Return the cumulative sum of ``x`` along ``axis``.

    Parameters
    ----------
    x : Array
        Input array. Must have a numeric dtype and at least one dimension.
    axis : int, optional
        Axis along which to accumulate. May only be omitted when ``x`` is
        one-dimensional, in which case it defaults to ``0``.
    dtype : DType, optional
        Result dtype; converted with ``_np_dtype`` and passed to ``np.cumsum``.
    include_initial : bool
        If True, a zero is prepended along ``axis``, so the result is one
        element longer than ``x`` along that axis.

    Returns
    -------
    Array
        The running sums, placed on the same device as ``x``.

    Raises
    ------
    TypeError
        If ``x`` does not have a numeric dtype.
    ValueError
        If ``x`` is zero-dimensional, if ``axis`` is omitted for a
        multi-dimensional ``x``, or if ``include_initial`` is True and
        ``axis`` is out of bounds for ``x``.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
    # The standard says x "should" have at least one dimension. Reject 0-D
    # input explicitly (as cumulative_prod does) rather than letting NumPy or
    # concat fail with an unrelated-looking error.
    if x.ndim == 0:
        raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_sum")
    if axis is None:
        if x.ndim > 1:
            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
        axis = 0
    # np.cumsum does not support include_initial, so prepend zeros manually.
    if include_initial:
        # The padding shape is built by slicing x.shape. An out-of-range axis
        # would silently yield a wrong-rank shape, so check bounds first.
        if not -x.ndim <= axis < x.ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for array of dimension {x.ndim}"
            )
        if axis < 0:
            axis += x.ndim
        pad_shape = x.shape[:axis] + (1,) + x.shape[axis + 1:]
        # Allocate the padding on x's device. Without device= it would land on
        # the default device, which need not be the device of x.
        padding = zeros(pad_shape, dtype=x.dtype, device=x.device)
        x = concat([padding, x], axis=axis)
    return Array._new(np.cumsum(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device)
@requires_api_version('2024.12')
def cumulative_prod(
    x: Array,
    /,
    *,
    axis: int | None = None,
    dtype: DType | None = None,
    include_initial: bool = False,
) -> Array:
    """Return the cumulative product of ``x`` along ``axis``.

    Parameters
    ----------
    x : Array
        Input array. Must have a numeric dtype and at least one dimension.
    axis : int, optional
        Axis along which to accumulate. May only be omitted when ``x`` is
        one-dimensional, in which case it defaults to ``0``.
    dtype : DType, optional
        Result dtype; converted with ``_np_dtype`` and passed to ``np.cumprod``.
    include_initial : bool
        If True, a one is prepended along ``axis``, so the result is one
        element longer than ``x`` along that axis.

    Returns
    -------
    Array
        The running products, placed on the same device as ``x``.

    Raises
    ------
    TypeError
        If ``x`` does not have a numeric dtype.
    ValueError
        If ``x`` is zero-dimensional, if ``axis`` is omitted for a
        multi-dimensional ``x``, or if ``include_initial`` is True and
        ``axis`` is out of bounds for ``x``.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in cumulative_prod")
    if x.ndim == 0:
        raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod")
    if axis is None:
        if x.ndim > 1:
            raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
        axis = 0
    # np.cumprod does not support include_initial, so prepend ones manually.
    if include_initial:
        # The padding shape is built by slicing x.shape. An out-of-range axis
        # would silently yield a wrong-rank shape, so check bounds first.
        if not -x.ndim <= axis < x.ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for array of dimension {x.ndim}"
            )
        if axis < 0:
            axis += x.ndim
        pad_shape = x.shape[:axis] + (1,) + x.shape[axis + 1:]
        # Allocate the padding on x's device. Without device= it would land on
        # the default device, which need not be the device of x.
        padding = ones(pad_shape, dtype=x.dtype, device=x.device)
        x = concat([padding, x], axis=axis)
    return Array._new(np.cumprod(x._array, axis=axis, dtype=_np_dtype(dtype)), device=x.device)
def max(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
) -> Array:
    """Return the maximum of ``x`` over ``axis`` (every axis when ``None``).

    Raises TypeError when ``x.dtype`` is not in ``_real_numeric_dtypes``.
    The result is placed on the same device as ``x``.
    """
    if x.dtype in _real_numeric_dtypes:
        largest = np.max(x._array, axis=axis, keepdims=keepdims)
        return Array._new(largest, device=x.device)
    raise TypeError("Only real numeric dtypes are allowed in max")
def mean(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
) -> Array:
    """Return the arithmetic mean of ``x`` over ``axis``.

    The accepted dtypes depend on the active API version. For versions newer
    than 2023.12 any dtype in ``_floating_dtypes`` is allowed; older versions
    only allow ``_real_floating_dtypes``. Anything else raises TypeError.
    """
    # Version strings share the "YYYY.MM" format, so plain string comparison
    # orders them chronologically.
    newer_than_2023_12 = get_array_api_strict_flags()['api_version'] > '2023.12'
    if newer_than_2023_12:
        permitted = _floating_dtypes
    else:
        permitted = _real_floating_dtypes
    if x.dtype not in permitted:
        raise TypeError("Only floating-point dtypes are allowed in mean")
    average = np.mean(x._array, axis=axis, keepdims=keepdims)
    return Array._new(average, device=x.device)
def min(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
) -> Array:
    """Return the minimum of ``x`` over ``axis`` (every axis when ``None``).

    Raises TypeError when ``x.dtype`` is not in ``_real_numeric_dtypes``.
    The result is placed on the same device as ``x``.
    """
    if x.dtype in _real_numeric_dtypes:
        smallest = np.min(x._array, axis=axis, keepdims=keepdims)
        return Array._new(smallest, device=x.device)
    raise TypeError("Only real numeric dtypes are allowed in min")
def _np_dtype_sumprod(x: Array, dtype: DType | None) -> np.dtype[Any] | None:
    """Choose the NumPy ``dtype=`` argument for ``np.sum`` / ``np.prod`` on ``x``.

    An explicit ``dtype`` is always honoured and converted with ``_np_dtype``.
    When ``dtype`` is None and the active API version is older than 2023.12,
    float32 input is accumulated as float64 and complex64 as complex128, which
    is the upcasting those versions required. In every other case the result
    is ``_np_dtype(dtype)``, which leaves the choice to NumPy (NumPy itself
    already promotes narrow integer dtypes when summing).
    """
    # Version strings share the "YYYY.MM" format, so string comparison works.
    if dtype is None and get_array_api_strict_flags()['api_version'] < '2023.12':
        # Return real np.dtype instances so the annotation holds. NumPy
        # reductions accept these exactly as they accept the scalar types.
        if x.dtype == float32:
            return np.dtype(np.float64)
        if x.dtype == complex64:
            return np.dtype(np.complex128)
    return _np_dtype(dtype)
def prod(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    dtype: DType | None = None,
    keepdims: bool = False,
) -> Array:
    """Return the product of the elements of ``x`` over ``axis``.

    Raises TypeError for dtypes outside ``_numeric_dtypes``. The NumPy
    accumulation dtype comes from ``_np_dtype_sumprod``, which takes both
    ``dtype`` and the active API version into account.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in prod")
    product = np.prod(
        x._array,
        dtype=_np_dtype_sumprod(x, dtype),
        axis=axis,
        keepdims=keepdims,
    )
    return Array._new(product, device=x.device)
def std(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    correction: int | float = 0.0,
    keepdims: bool = False,
) -> Array:
    """Return the standard deviation of ``x`` over ``axis``.

    ``correction`` is forwarded as NumPy's ``ddof``, so the divisor is
    ``N - correction``. Raises TypeError for dtypes outside
    ``_real_floating_dtypes``.
    """
    if x.dtype not in _real_floating_dtypes:
        raise TypeError("Only real floating-point dtypes are allowed in std")
    # The array API spells NumPy's ``ddof`` keyword as ``correction``.
    spread = np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)
    return Array._new(spread, device=x.device)
def sum(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    dtype: DType | None = None,
    keepdims: bool = False,
) -> Array:
    """Return the sum of the elements of ``x`` over ``axis``.

    Raises TypeError for dtypes outside ``_numeric_dtypes``. The NumPy
    accumulation dtype comes from ``_np_dtype_sumprod``, which takes both
    ``dtype`` and the active API version into account.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in sum")
    total = np.sum(
        x._array,
        axis=axis,
        dtype=_np_dtype_sumprod(x, dtype),
        keepdims=keepdims,
    )
    return Array._new(total, device=x.device)
def var(
    x: Array,
    /,
    *,
    axis: int | tuple[int, ...] | None = None,
    correction: int | float = 0.0,
    keepdims: bool = False,
) -> Array:
    """Return the variance of ``x`` over ``axis``.

    ``correction`` is forwarded as NumPy's ``ddof``, so the divisor is
    ``N - correction``. Raises TypeError for dtypes outside
    ``_real_floating_dtypes``.
    """
    if x.dtype not in _real_floating_dtypes:
        raise TypeError("Only real floating-point dtypes are allowed in var")
    # The array API spells NumPy's ``ddof`` keyword as ``correction``.
    variance = np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)
    return Array._new(variance, device=x.device)