forked from data-apis/array-api-strict
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_statistical_functions.py
167 lines (144 loc) · 5.2 KB
/
_statistical_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from __future__ import annotations
from ._dtypes import (
_real_floating_dtypes,
_real_numeric_dtypes,
_floating_dtypes,
_numeric_dtypes,
)
from ._array_object import Array
from ._dtypes import float32, complex64
from ._flags import requires_api_version, get_array_api_strict_flags
from ._creation_functions import zeros
from ._manipulation_functions import concat
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Tuple, Union
from ._typing import Dtype
import numpy as np
@requires_api_version('2023.12')
def cumulative_sum(
    x: Array,
    /,
    *,
    axis: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    include_initial: bool = False,
) -> Array:
    """Return the cumulative sum of ``x`` along ``axis``.

    Parameters
    ----------
    x : Array
        Input array; must have a numeric dtype.
    axis : int, optional
        Axis to accumulate along. May only be omitted when ``x`` has at most
        one dimension, in which case axis 0 is used. Negative values count
        from the last axis.
    dtype : Dtype, optional
        Dtype to accumulate in; its NumPy equivalent is passed to
        ``np.cumsum``. When omitted, ``np.cumsum`` picks the dtype.
    include_initial : bool
        If True, a zero is prepended along ``axis`` so the result is one
        element longer than ``x`` along that axis.

    Returns
    -------
    Array
        The running sums, on the same device as ``x``.

    Raises
    ------
    TypeError
        If ``x`` does not have a numeric dtype.
    ValueError
        If ``axis`` is omitted and ``x`` has more than one dimension.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
    np_dtype = None if dtype is None else dtype._np_dtype

    # TODO: The standard is not clear about what should happen when x.ndim == 0.
    if axis is None:
        if x.ndim > 1:
            raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
        axis = 0
    # np.cumsum does not support include_initial, so prepend a slab of zeros
    # along ``axis`` and accumulate over the extended array instead.
    if include_initial:
        if axis < 0:
            axis += x.ndim
        # The padding uses x's own dtype: zeros of the requested output dtype
        # could be of a different kind (e.g. float vs. int) and be rejected by
        # concat()'s type promotion. np.cumsum's ``dtype`` does the cast below.
        # It is also created on x's device rather than the default device so
        # concat() sees both inputs on the same device.
        initial = zeros(
            x.shape[:axis] + (1,) + x.shape[axis + 1:],
            dtype=x.dtype,
            device=x.device,
        )
        x = concat([initial, x], axis=axis)
    return Array._new(np.cumsum(x._array, axis=axis, dtype=np_dtype), device=x.device)
def max(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
) -> Array:
    """Return the maximum of ``x`` along ``axis``.

    Only real numeric dtypes are accepted; any other dtype raises
    ``TypeError``. The reduction is delegated to ``np.max`` and the result is
    wrapped back into an ``Array`` on the same device as ``x``.
    """
    if x.dtype in _real_numeric_dtypes:
        largest = np.max(x._array, axis=axis, keepdims=keepdims)
        return Array._new(largest, device=x.device)
    raise TypeError("Only real numeric dtypes are allowed in max")
def mean(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
) -> Array:
    """Return the arithmetic mean of ``x`` along ``axis``.

    The accepted dtypes depend on the active API version flag. Versions newer
    than 2023.12 accept every dtype in ``_floating_dtypes``. Older versions
    only accept ``_real_floating_dtypes``. Anything else raises ``TypeError``.
    """
    api_version = get_array_api_strict_flags()['api_version']
    # NOTE(review): the wider set presumably adds complex dtypes; verify in _dtypes.
    permitted = _floating_dtypes if api_version > '2023.12' else _real_floating_dtypes
    if x.dtype not in permitted:
        raise TypeError("Only floating-point dtypes are allowed in mean")
    averaged = np.mean(x._array, axis=axis, keepdims=keepdims)
    return Array._new(averaged, device=x.device)
def min(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
) -> Array:
    """Return the minimum of ``x`` over ``axis``.

    Raises ``TypeError`` unless ``x`` has a real numeric dtype. The work is
    done by ``np.min``; the result is rewrapped on ``x``'s device.
    """
    if x.dtype not in _real_numeric_dtypes:
        raise TypeError("Only real numeric dtypes are allowed in min")
    smallest = np.min(x._array, axis=axis, keepdims=keepdims)
    return Array._new(smallest, device=x.device)
def prod(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
) -> Array:
    """Return the product of the elements of ``x`` along ``axis``.

    An explicit ``dtype`` is forwarded to ``np.prod`` as its NumPy dtype.
    Without one, API versions older than 2023.12 upcast float32 to float64
    and complex64 to complex128. From 2023.12 onward NumPy's default result
    dtype is used. Raises ``TypeError`` for non-numeric input.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in prod")

    np_dtype = None
    if dtype is not None:
        np_dtype = dtype._np_dtype
    elif get_array_api_strict_flags()['api_version'] < '2023.12':
        # Legacy (pre-2023.12) behaviour: sum() and prod() upcast every
        # dtype when none is given. From 2023.12 only integral dtypes are
        # upcast, which is what NumPy already does on its own.
        if x.dtype == float32:
            np_dtype = np.float64
        elif x.dtype == complex64:
            np_dtype = np.complex128

    product = np.prod(x._array, dtype=np_dtype, axis=axis, keepdims=keepdims)
    return Array._new(product, device=x.device)
def std(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
) -> Array:
    """Return the standard deviation of ``x`` along ``axis``.

    ``correction`` is the array-API name for NumPy's ``ddof`` (degrees of
    freedom subtracted from the divisor). Only real floating-point dtypes are
    accepted; others raise ``TypeError``.
    """
    if x.dtype in _real_floating_dtypes:
        # The standard spells NumPy's ``ddof`` keyword as ``correction``.
        deviation = np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)
        return Array._new(deviation, device=x.device)
    raise TypeError("Only real floating-point dtypes are allowed in std")
def sum(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
) -> Array:
    """Return the sum of the elements of ``x`` along ``axis``.

    An explicit ``dtype`` is forwarded to ``np.sum`` as its NumPy dtype.
    Without one, API versions older than 2023.12 upcast float32 to float64
    and complex64 to complex128. From 2023.12 onward NumPy's default result
    dtype is used. Raises ``TypeError`` for non-numeric input.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in sum")

    if dtype is not None:
        target = dtype._np_dtype
    else:
        target = None
        # Before 2023.12, sum() and prod() upcast all dtypes when dtype=None.
        # From 2023.12 on, the behaviour matches NumPy (integral upcast only),
        # so nothing needs to be passed.
        if get_array_api_strict_flags()['api_version'] < '2023.12':
            if x.dtype == float32:
                target = np.float64
            elif x.dtype == complex64:
                target = np.complex128

    total = np.sum(x._array, axis=axis, dtype=target, keepdims=keepdims)
    return Array._new(total, device=x.device)
def var(
    x: Array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,
    keepdims: bool = False,
) -> Array:
    """Return the variance of ``x`` along ``axis``.

    ``correction`` maps directly onto NumPy's ``ddof`` argument. Only real
    floating-point dtypes are accepted; others raise ``TypeError``.
    """
    if x.dtype not in _real_floating_dtypes:
        raise TypeError("Only real floating-point dtypes are allowed in var")
    # The standard spells NumPy's ``ddof`` keyword as ``correction``.
    variance = np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)
    return Array._new(variance, device=x.device)