We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3496e68 commit 5d4e9e0Copy full SHA for 5d4e9e0
tests/link/jax/test_tensor_basic.py
@@ -29,7 +29,7 @@ def test_jax_Alloc():
29
x = ptb.AllocEmpty("float32")(2, 3)
30
31
def compare_shape_dtype(x, y):
32
- np.testing.assert_array_equal(x, y, strict=True)
+ assert x.shape == y.shape and x.dtype == y.dtype
33
34
compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype)
35
0 commit comments