-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_basic_ops.py
341 lines (252 loc) · 9.12 KB
/
test_basic_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
#
import pytest
import jax
import jax.numpy as jnp
import numpy
from infrastructure import verify_module
def test_abs_op():
    """Verify jnp.abs on a rank-2 and a rank-3 input."""

    def module_abs(a):
        return jnp.abs(a)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_abs, [shape])
# Skipped: the broadcasted values produced are incorrect.
@pytest.mark.skip("Broadcasted values are incorrect")
def test_broadcast_op():
    """Verify jnp.broadcast_to expanding a (2, 1) input to (2, 4)."""

    def module_broadcast(a):
        target_shape = (2, 4)
        return jnp.broadcast_to(a, target_shape)

    input_shapes = [(2, 1)]
    verify_module(module_broadcast, input_shapes)
def test_cbrt_op():
    """Verify jax.lax.cbrt on a rank-2 and a rank-3 input.

    A loosened tolerance is used: the observed ATOL on (3, 3) was about
    0.01004, so 2e-2 leaves headroom.
    """

    def module_cbrt(a):
        return jax.lax.cbrt(a)

    cbrt_atol = 2e-2
    for shape in [(3, 3), (3, 3, 3)]:
        verify_module(module_cbrt, [shape], required_atol=cbrt_atol)
def test_concat_op():
    """Verify jnp.concatenate along axes 0-3 for rank-2, -3 and -4 inputs."""

    def module_concat_dim_0(x, y):
        return jnp.concatenate([x, y], axis=0)

    def module_concat_dim_1(x, y):
        return jnp.concatenate([x, y], axis=1)

    def module_concat_dim_2(x, y):
        return jnp.concatenate([x, y], axis=2)

    def module_concat_dim_3(x, y):
        return jnp.concatenate([x, y], axis=3)

    # (module, input shapes); the trailing comment is the expected output shape.
    cases = [
        (module_concat_dim_0, [(32, 32), (64, 32)]),  # (96, 32)
        (module_concat_dim_0, [(32, 32, 32), (64, 32, 32)]),  # (96, 32, 32)
        (module_concat_dim_1, [(32, 32), (32, 64)]),  # (32, 96)
        (module_concat_dim_1, [(32, 32, 32), (32, 32, 32)]),  # (32, 64, 32)
        (module_concat_dim_2, [(32, 32, 32), (32, 32, 64)]),  # (32, 32, 96)
        (
            module_concat_dim_2,
            [(32, 32, 32, 32), (32, 32, 64, 32)],
        ),  # (32, 32, 96, 32)
        (
            module_concat_dim_3,
            [(32, 32, 32, 32), (32, 32, 32, 64)],
        ),  # (32, 32, 32, 96)
    ]
    for module, input_shapes in cases:
        verify_module(module, input_shapes)
# Two failures have been recorded for this test: "'ttir.constant' op failed to
# verify that all of {value, result} have same shape", and the index-out-of-
# bounds error used as the skip reason.
@pytest.mark.skip(
    "Index is out of bounds for the rank, should be between 0 and 0 however is 18446744073709551615"
)
def test_constant_op():
    """Verify modules that return constants instead of computing on the input.

    ``module_constant_zeros`` and ``module_constant_ones`` build arrays with the
    input's shape. ``module_constant_multi`` returns a fixed 2x2 float32 array
    regardless of the (3, 3) input.
    """

    def module_constant_zeros(a):
        return jnp.zeros(a.shape)

    def module_constant_ones(a):
        return jnp.ones(a.shape)

    def module_constant_multi(a):
        return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)

    constant_modules = (
        module_constant_zeros,
        module_constant_ones,
        module_constant_multi,
    )
    for module in constant_modules:
        verify_module(module, [(3, 3)])
def test_convert_op():
    """Verify jax.lax.convert_element_type casting inputs to bfloat16."""

    def module_convert(a):
        return jax.lax.convert_element_type(a, jnp.bfloat16)

    for shape in [(2, 2), (4, 4, 4)]:
        verify_module(module_convert, [shape])
def test_div_op():
    """Verify elementwise division; the rank-3 case uses a loose 0.35 ATOL."""

    def module_div(a, b):
        return a / b

    verify_module(module_div, [(3, 3)] * 2)
    verify_module(module_div, [(3, 3, 3)] * 2, required_atol=35e-2)
@pytest.mark.skip("VHLO Legalization failed.")
def test_dot_general_op():
    """Verify jnp.dot on (2, 1)x(1, 2) and (1, 2)x(2, 1) operand pairs."""

    def module_dot_general(a, b):
        return jnp.dot(a, b)

    shape_pairs = (
        [(2, 1), (1, 2)],  # result (2, 2)
        [(1, 2), (2, 1)],  # result (1, 1)
    )
    for input_shapes in shape_pairs:
        verify_module(module_dot_general, input_shapes)
# exp generates slightly different values on tt-metal, so higher ATOLs are used.
# See tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199
def test_exp_op():
    """Verify jnp.exp on rank-2 and rank-3 inputs with per-shape tolerances."""

    def module_exp(a):
        return jnp.exp(a)

    for shape, atol in [((3, 3), 20e-2), ((3, 3, 3), 25e-2)]:
        verify_module(module_exp, [shape], required_atol=atol)
def test_maximum_op():
    """Verify elementwise jnp.maximum on equally shaped rank-2/rank-3 pairs."""

    def module_maximum(a, b):
        return jnp.maximum(a, b)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_maximum, [shape, shape])
def test_multiply_op():
    """Verify elementwise multiplication on equally shaped rank-2/rank-3 pairs."""

    def module_multiply(a, b):
        return a * b

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_multiply, [shape, shape])
def test_negate_op():
    """Verify unary negation on a rank-2 and a rank-3 input."""

    def module_negate(a):
        return -a

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_negate, [shape])
# Skipped: full reductions (keepdim=False) are not supported. An earlier note
# also attributed the failure to an error in constant handling.
@pytest.mark.skip("keepdim=False is not supported")
def test_reduce_op():
    """Verify all-axis max and sum reductions on rank-2 and rank-3 inputs."""

    def module_reduce_max(a):
        return jnp.max(a)

    def module_reduce_sum(a):
        return jnp.sum(a)

    for module in (module_reduce_max, module_reduce_sum):
        for shape in ((3, 3), (3, 3, 3)):
            verify_module(module, [shape])
def test_rsqrt_op():
    """Verify jax.lax.rsqrt on a rank-2 and a rank-3 input."""

    def module_rsqrt(a):
        return jax.lax.rsqrt(a)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_rsqrt, [shape])
# expm1 needs a bigger ATOL because of inaccuracies in the exp op on tt-metal.
# See tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199
def test_expm1_op():
    """Verify jax.lax.expm1 on rank-2 and rank-3 inputs with ATOL 0.2."""

    def module_expm1(a):
        return jax.lax.expm1(a)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_expm1, [shape], required_atol=20e-2)
def test_log1p_op():
    """Verify jax.lax.log1p on rank-2 and rank-3 inputs with ATOL 0.02."""

    def module_log1p(a):
        return jax.lax.log1p(a)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_log1p, [shape], required_atol=2e-2)
def test_sign_op():
    """Verify jax.lax.sign on a rank-2 and a rank-3 input."""

    def module_sign(a):
        return jax.lax.sign(a)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_sign, [shape])
def test_sqrt_op():
    """Verify jnp.sqrt on a rank-2 and a rank-3 input."""

    def module_sqrt(a):
        return jnp.sqrt(a)

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_sqrt, [shape])
def test_sub_op():
    """Verify elementwise subtraction on equally shaped rank-2/rank-3 pairs."""

    def module_sub(a, b):
        return a - b

    for shape in ((3, 3), (3, 3, 3)):
        verify_module(module_sub, [shape, shape])
def test_transpose_op_2d():
    """Verify jnp.transpose on a square rank-2 input."""

    def module_transpose(a):
        return jnp.transpose(a)

    square = (3, 3)
    verify_module(module_transpose, [square])
@pytest.mark.skip(
    "Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
)
def test_scalar_type():
    """Verify a module whose output is a scalar read from the input's shape."""

    def module_scalar_type(a):
        leading_dim = a.shape[0]
        return leading_dim

    verify_module(module_scalar_type, [(3, 3)])
# Skipped: the transpose op currently fails for higher ranks.
@pytest.mark.skip("Transpose op failing for higher ranks/dimensions.")
def test_transpose_op_3d():
    """Verify jnp.transpose (default axis order) on a rank-3 input."""

    def module_transpose(a):
        return jnp.transpose(a)

    cube = (3, 3, 3)
    verify_module(module_transpose, [cube])
# (begin, end, dim) parameter grids consumed by test_slice.
# Axes 0 and 1: every begin in [0, 10) paired with every end in [90, 100).
dim0_cases = [(lo, hi, 0) for lo in range(10) for hi in range(90, 100)]
dim1_cases = [(lo, hi, 1) for lo in range(10) for hi in range(90, 100)]
# Axes 2 and 3: begins {0, 32} paired with ends {64, 96} (multiples of 32).
dim2_cases = [(lo, hi, 2) for lo in range(0, 64, 32) for hi in range(64, 128, 32)]
dim3_cases = [(lo, hi, 3) for lo in range(0, 64, 32) for hi in range(64, 128, 32)]
@pytest.mark.parametrize(
    "begin, end, dim", [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
)
@pytest.mark.skip("Requires tt-metal uplift.")
def test_slice(begin, end, dim):
    """Verify slicing ``[begin:end]`` along axis ``dim`` of a rank-4 input.

    The input shape is (10, 10, 10, 10), except that axis ``dim`` is widened
    to 128 so every (begin, end) pair from the parameter grids is in range.
    """

    def module_slice(a):
        # Full slice on every axis except `dim`. The index is built generically
        # rather than with an if/elif chain whose `else` treated any other value
        # as axis 3. That way the sliced axis always matches the axis widened
        # below, including for negative `dim`.
        index = [slice(None)] * a.ndim
        index[dim] = slice(begin, end)
        return a[tuple(index)]

    shape = [10, 10, 10, 10]
    shape[dim] = 128
    verify_module(module_slice, [shape])
@pytest.mark.parametrize(
    "input_shapes",
    [
        [(32, 32), (32, 32)],
        # The smaller shapes are skipped until tt-xla issue 70 is resolved.
        *(
            pytest.param(
                [shape, shape],
                marks=pytest.mark.skip(
                    reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/70"
                ),
            )
            for shape in [(3, 3), (3, 3, 3)]
        ),
    ],
)
def test_remainder_op_lax(input_shapes):
    """Verify jax.lax.rem (elementwise remainder) with ATOL 0.02."""

    def module_remainder_lax(a, b):
        return jax.lax.rem(a, b)

    verify_module(module_remainder_lax, input_shapes, required_atol=0.02)
@pytest.mark.parametrize(
    "input_shapes",
    [
        pytest.param(
            [(32, 32), (32, 32)],
            marks=pytest.mark.skip(
                reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/71"
            ),
        ),
        # The smaller shapes are skipped until tt-xla issue 70 is resolved.
        *(
            pytest.param(
                [shape, shape],
                marks=pytest.mark.skip(
                    reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/70"
                ),
            )
            for shape in [(3, 3), (3, 3, 3)]
        ),
    ],
)
def test_remainder_op_jnp(input_shapes):
    """Verify jnp.remainder with ATOL 0.02.

    ``jnp.remainder`` generates a more complex StableHLO graph than
    ``jax.lax.rem`` (implicit broadcasts, etc.), which is why both variants
    are tested.
    """

    def module_remainder_jnp(a, b):
        return jnp.remainder(a, b)

    verify_module(module_remainder_jnp, input_shapes, required_atol=0.02)