forked from data-apis/array-api-tests
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharray_helpers.py
347 lines (275 loc) · 11.4 KB
/
array_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
logical_or, isfinite, greater, less, less_equal,
zeros, ones, full, bool, int8, int16, int32,
int64, uint8, uint16, uint32, uint64, float32,
float64, nan, inf, pi, remainder, divide, isinf,
negative, asarray)
# These are exported here so that they can be included in the special cases
# tests from this file.
from ._array_module import logical_not, subtract, floor, ceil, where
from . import dtype_helpers as dh
from ndindex import iter_indices
import math
# Public names for `from ... import *`. Besides the helpers defined in this
# module, several array-module functions imported above are re-exported so
# that the special-case tests can use them from this one module.
__all__ = [
    'all',
    'any',
    'logical_and',
    'logical_or',
    'logical_not',
    'less',
    'less_equal',
    'greater',
    'subtract',
    'negative',
    'floor',
    'ceil',
    'where',
    'isfinite',
    'equal',
    'not_equal',
    'zero',
    'one',
    'NaN',
    'infinity',
    'π',
    'isnegzero',
    'non_zero',
    'isposzero',
    'exactly_equal',
    'assert_exactly_equal',
    'notequal',
    'assert_finite',
    'assert_non_zero',
    'ispositive',
    'assert_positive',
    'isnegative',
    'assert_negative',
    'isintegral',
    'assert_integral',
    'isodd',
    'iseven',
    'assert_iseven',
    'assert_isinf',
    'positive_mathematical_sign',
    'assert_positive_mathematical_sign',
    'negative_mathematical_sign',
    'assert_negative_mathematical_sign',
    'same_sign',
    'assert_same_sign',
    'float64',
    'asarray',
    'full',
    'true',
    'false',
    'isnan',
]
def zero(shape, dtype):
    """
    Returns a full 0 array of the given shape and dtype.

    This should be used in place of the literal "0" in the test suite, as the
    spec does not require any behavior with Python literals (and in
    particular, it does not specify how the integer 0 and the float 0.0 work
    with type promotion).

    To get -0, use -zero(shape, dtype) (note that -0 is only defined for
    floating point dtypes).
    """
    return zeros(shape, dtype=dtype)
def one(shape, dtype):
    """
    Returns a full 1 array of the given shape and dtype.

    This should be used in place of the literal "1" in the test suite, as the
    spec does not require any behavior with Python literals (and in
    particular, it does not specify how the integer 1 and the float 1.0 work
    with type promotion).

    To get -1, use -one(shape, dtype).
    """
    return ones(shape, dtype=dtype)
def NaN(shape, dtype):
    """
    Return an array of the given shape filled with nan.

    Only float32 and float64 are accepted, because nan exists only for
    floating point dtypes. Any other dtype raises RuntimeError.
    """
    supported = (float32, float64)
    if dtype in supported:
        return full(shape, nan, dtype=dtype)
    raise RuntimeError(f"Unexpected dtype {dtype} in NaN().")
def infinity(shape, dtype):
    """
    Returns a full positive infinity array of the given shape and dtype.

    Note that this is only defined for floating point dtypes (float32 and
    float64); any other dtype raises RuntimeError.

    To get negative infinity, use -infinity(shape, dtype).
    """
    if dtype not in [float32, float64]:
        raise RuntimeError(f"Unexpected dtype {dtype} in infinity().")
    return full(shape, inf, dtype=dtype)
def π(shape, dtype):
    """
    Returns a full π array of the given shape and dtype.

    Note that this function is only defined for floating point dtypes
    (float32 and float64); any other dtype raises RuntimeError.

    To get rational multiples of π, use, e.g., 3*π(shape, dtype)/2.
    """
    if dtype not in [float32, float64]:
        raise RuntimeError(f"Unexpected dtype {dtype} in π().")
    return full(shape, pi, dtype=dtype)
def true(shape):
    """
    Return an array of the given shape in which every element is True.

    The result always has the boolean dtype.
    """
    fill_value = True
    return full(shape, fill_value, dtype=bool)
def false(shape):
    """
    Return an array of the given shape in which every element is False.

    The result always has the boolean dtype.
    """
    fill_value = False
    return full(shape, fill_value, dtype=bool)
def isnegzero(x):
    """
    Returns a mask where x is -0. Is all False if x has integer dtype.

    For floating point x, an element counts as -0 when it compares equal to
    zero AND 1/x is -infinity. Checking 1/x alone is not enough: for tiny
    negative subnormal values, 1/x overflows to -infinity even though the
    value is not -0.
    """
    # TODO: If copysign or signbit are added to the spec, use those instead.
    shape = x.shape
    dtype = x.dtype
    if dh.is_int_dtype(dtype):
        return false(shape)
    # -0 == 0 is True, so this keeps both signed zeros and drops subnormals.
    is_zero = equal(x, zero(shape, dtype))
    # Among the zeros, only -0 has a reciprocal of -infinity.
    reciprocal_is_neginf = equal(divide(one(shape, dtype), x), -infinity(shape, dtype))
    return logical_and(is_zero, reciprocal_is_neginf)
def isposzero(x):
    """
    Returns a mask where x is +0 (but not -0). Is all True if x has integer dtype.

    For floating point x, an element counts as +0 when it compares equal to
    zero AND 1/x is +infinity. Checking 1/x alone is not enough: for tiny
    positive subnormal values, 1/x overflows to +infinity even though the
    value is not +0.

    Note that the integer case returns all True regardless of the values.
    Callers that need an "x == 0" mask for integers should compare directly.
    """
    # TODO: If copysign or signbit are added to the spec, use those instead.
    shape = x.shape
    dtype = x.dtype
    if dh.is_int_dtype(dtype):
        return true(shape)
    # +0 == 0 is True, so this keeps both signed zeros and drops subnormals.
    is_zero = equal(x, zero(shape, dtype))
    # Among the zeros, only +0 has a reciprocal of +infinity.
    reciprocal_is_posinf = equal(divide(one(shape, dtype), x), infinity(shape, dtype))
    return logical_and(is_zero, reciprocal_is_posinf)
def exactly_equal(x, y):
    """
    Elementwise equality that treats nan as equal to nan and tells +0 and -0
    apart.

    For float32/float64 inputs, an element is True when the two values are
    equal (or both nan) AND they agree on being -0 AND agree on being +0.
    For any other dtype, this is plain equal(x, y).

    This function implicitly assumes x and y have the same shape and dtype.
    """
    if x.dtype not in [float32, float64]:
        return equal(x, y)
    # nan != nan, so "both nan" is added back in as a match.
    values_match = logical_or(equal(x, y), logical_and(isnan(x), isnan(y)))
    # -0 == +0 compares True, so the signed-zero masks must also agree.
    negzero_match = equal(isnegzero(x), isnegzero(y))
    poszero_match = equal(isposzero(x), isposzero(y))
    return logical_and(logical_and(values_match, negzero_match), poszero_match)
def allclose(x, y, rel_tol=0.25, abs_tol=1, return_indices=False):
    """
    Return True if all elements of x and y are within tolerance.

    Elements are paired using the index pairs (i, j) yielded by
    ndindex.iter_indices(x.shape, y.shape). Each finite pair is compared with
    math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol).

    Non-finite pairs:
    - if both values are infinite, they must be the same infinity;
    - any other pair involving a nan or an infinity is not compared (it does
      not make the result False).

    If return_indices=True, returns (False, (i, j)) when the arrays are not
    close, where i and j are the indices into x and y of corresponding
    non-close elements.
    """
    for i, j in iter_indices(x.shape, y.shape):
        i, j = i.raw, j.raw
        a = x[i]
        b = y[j]
        if math.isinf(a) and math.isinf(b):
            # Both infinite: only the same-signed infinity counts as close.
            close = bool(a == b)
        elif not (math.isfinite(a) and math.isfinite(b)):
            # A nan on either side, or only one side infinite: skip this pair.
            continue
        else:
            close = math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        if not close:
            if return_indices:
                return (False, (i, j))
            return False
    return True
def assert_allclose(x, y, rel_tol=0.25, abs_tol=1):
    """
    Assert that x and y have the same shape and dtype, and that their values
    are close according to allclose() with the given tolerances.

    On a value mismatch, the AssertionError message reports the tolerances and
    the pair of indices (i, j) that allclose() returned.
    """
    assert x.shape == y.shape, (
        f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
    )
    assert x.dtype == y.dtype, (
        f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
    )
    outcome = allclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol, return_indices=True)
    if outcome is True:
        return
    # outcome is (False, (i, j)); the names i and j appear in the message.
    _, (i, j) = outcome
    raise AssertionError(f"The input arrays are not close with {rel_tol = } and {abs_tol = } at indices {i = } and {j = }")
def notequal(x, y):
    """
    Same as not_equal(x, y) except it gives False where both values are nan.

    Note: this function does NOT distinguish +0 and -0.

    This function implicitly assumes x and y have the same shape and dtype.
    """
    differs = not_equal(x, y)
    if x.dtype not in [float32, float64]:
        return differs
    # nan != nan is True, so positions where both sides are nan are masked out:
    # result = NOT (both nan) AND (x != y)
    both_nan = logical_and(isnan(x), isnan(y))
    return logical_and(logical_not(both_nan), differs)
def assert_exactly_equal(x, y):
    """
    Assert that x and y are exactly equal.

    The arrays must have the same shape, the same dtype, and values that match
    under exactly_equal(): nan matches nan, and for float32/float64 the
    signs of zeros must agree.
    """
    assert x.shape == y.shape, (
        f"The input arrays do not have the same shapes ({x.shape} != {y.shape})"
    )
    assert x.dtype == y.dtype, (
        f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})"
    )
    elementwise = exactly_equal(x, y)
    assert all(elementwise), "The input arrays have different values"
def assert_finite(x):
    """
    Assert that every element of the array x is finite, as reported by
    isfinite().
    """
    finite_mask = isfinite(x)
    assert all(finite_mask), "The input array is not finite"
def non_zero(x):
    """Return a boolean mask that is True where x is not equal to zero."""
    zero_array = zero(x.shape, x.dtype)
    return not_equal(x, zero_array)
def assert_non_zero(x):
    """Assert that no element of x compares equal to zero."""
    mask = non_zero(x)
    assert all(mask), "The input array is not nonzero"
def ispositive(x):
    """Return a boolean mask that is True where x is strictly greater than zero."""
    zero_array = zero(x.shape, x.dtype)
    return greater(x, zero_array)
def assert_positive(x):
    """Assert that every element of x is strictly greater than zero."""
    mask = ispositive(x)
    assert all(mask), "The input array is not positive"
def isnegative(x):
    """Return a boolean mask that is True where x is strictly less than zero."""
    zero_array = zero(x.shape, x.dtype)
    return less(x, zero_array)
def assert_negative(x):
    """Assert that every element of x is strictly less than zero."""
    mask = isnegative(x)
    assert all(mask), "The input array is not negative"
def inrange(x, a, b, epsilon=0, open=False):
    """
    Returns a mask for values of x in the range [a-epsilon, b+epsilon], inclusive.

    If open=True, the range is (a-epsilon, b+epsilon) (i.e., not inclusive).

    epsilon is first turned into an array with x's shape and dtype, then
    subtracted from a and added to b.
    """
    eps = full(x.shape, epsilon, dtype=x.dtype)
    # `open` selects strict (<) vs non-strict (<=) comparisons at both ends.
    compare = less if open else less_equal
    return logical_and(compare(a - eps, x), compare(x, b + eps))
def isintegral(x):
    """
    Returns a mask on x where the values are integral.

    Every element of an integer-dtype array is integral. For float32/float64,
    an element is integral when its remainder modulo 1 equals zero. Any other
    dtype yields an all-False mask.
    """
    integer_dtypes = [int8, int16, int32, int64, uint8, uint16, uint32, uint64]
    floating_dtypes = [float32, float64]
    shape, dtype = x.shape, x.dtype
    if dtype in integer_dtypes:
        return full(shape, True, dtype=bool)
    if dtype in floating_dtypes:
        fractional_part = remainder(x, one(shape, dtype))
        return equal(fractional_part, zero(shape, dtype))
    return full(shape, False, dtype=bool)
def assert_integral(x):
    """
    Assert that every element of x is integral, as defined by isintegral().
    """
    mask = isintegral(x)
    assert all(mask), "The input array has nonintegral values"
def isodd(x):
    """
    Returns a mask on x where the values are odd integers.

    An element is odd when it is integral (see isintegral) and its remainder
    modulo 2 equals 1.
    """
    shape, dtype = x.shape, x.dtype
    two = 2*one(shape, dtype)
    remainder_is_one = equal(remainder(x, two), one(shape, dtype))
    return logical_and(isintegral(x), remainder_is_one)
def iseven(x):
    """
    Returns a mask on x where the values are even integers.

    An element is even when it is integral (see isintegral) and its remainder
    modulo 2 equals 0.
    """
    shape, dtype = x.shape, x.dtype
    two = 2*one(shape, dtype)
    remainder_is_zero = equal(remainder(x, two), zero(shape, dtype))
    return logical_and(isintegral(x), remainder_is_zero)
def assert_iseven(x):
    """
    Assert that every element of x is an even integer, as defined by iseven().
    """
    mask = iseven(x)
    assert all(mask), "The input array is not even"
def assert_isinf(x):
    """
    Assert that every element of x is an infinity (of either sign), as
    reported by isinf().
    """
    inf_mask = isinf(x)
    assert all(inf_mask), "The input array is not infinite"
def positive_mathematical_sign(x):
    """
    Check if x has a positive "mathematical sign"

    The "mathematical sign" here means the sign bit is 0. This includes 0,
    positive finite numbers, and positive infinity. It does not include any
    nans, as signed nans are not required by the spec.

    For non-floating point (integer) dtypes there is no signed zero, so the
    result is simply x >= 0.
    """
    z = zero(x.shape, x.dtype)
    if x.dtype in [float32, float64]:
        # +0 is not > 0, so it is picked out separately (excluding -0).
        return logical_or(greater(x, z), isposzero(x))
    # isposzero() is all True for integer dtypes, so it cannot be used here:
    # it would report every integer, including negative ones, as positive.
    return logical_or(greater(x, z), equal(x, z))
def assert_positive_mathematical_sign(x):
    """Assert that every element of x has a positive mathematical sign."""
    mask = positive_mathematical_sign(x)
    assert all(mask), "The input arrays do not have a positive mathematical sign"
def negative_mathematical_sign(x):
    """
    Check if x has a negative "mathematical sign"

    The "mathematical sign" here means the sign bit is 1. This includes -0,
    negative finite numbers, and negative infinity. It does not include any
    nans, as signed nans are not required by the spec.

    For dtypes other than float32/float64, the result is simply x < 0.
    """
    z = zero(x.shape, x.dtype)
    below_zero = less(x, z)
    if x.dtype not in [float32, float64]:
        return below_zero
    # -0 is not < 0, so it has to be detected separately.
    return logical_or(below_zero, isnegzero(x))
def assert_negative_mathematical_sign(x):
    """Assert that every element of x has a negative mathematical sign."""
    mask = negative_mathematical_sign(x)
    assert all(mask), "The input arrays do not have a negative mathematical sign"
def same_sign(x, y):
    """
    Check if x and y have the "same sign"

    x and y have the same sign if both have a positive mathematical sign
    (nonnegative, including +0) or both have a negative mathematical sign
    (negative, including -0). So 0 and 1 share a sign, and so do -0 and -1.

    The result is False wherever x or y is nan, because neither sign helper
    counts nan, as signed nans are not required by the spec.
    """
    both_positive = logical_and(positive_mathematical_sign(x), positive_mathematical_sign(y))
    both_negative = logical_and(negative_mathematical_sign(x), negative_mathematical_sign(y))
    return logical_or(both_positive, both_negative)
def assert_same_sign(x, y):
    """Assert that x and y have the same sign at every position (see same_sign)."""
    mask = same_sign(x, y)
    assert all(mask), "The input arrays do not have the same sign"