Skip to content

Commit

Permalink
Merge pull request #59 from tenstorrent/ajakovljevic/adding_expm1_log…
Browse files Browse the repository at this point in the history
…1p_sign_op_tests

Added expm1, log1p and sign tests for jax
  • Loading branch information
ajakovljevicTT authored Nov 12, 2024
2 parents 4d10fdf + 955e5d9 commit 5476109
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def module_dot_general(a, b):


# Exponential generate slightly different values, so using higher ATOL value.
# see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199)
def test_exp_op():
def module_exp(a):
return jnp.exp(a)
Expand Down Expand Up @@ -174,6 +175,32 @@ def module_rsqrt(a):
verify_module(module_rsqrt, [(3, 3, 3)])


# Needs to have a bigger atol due to inaccuracies in the exp op on tt-metal
# see tt-mlir issue https://github.com/tenstorrent/tt-mlir/issues/1199)
def test_expm1_op():
def module_expm1(a):
return jax.lax.expm1(a)

verify_module(module_expm1, [(3, 3)], required_atol=20e-2)
verify_module(module_expm1, [(3, 3, 3)], required_atol=20e-2)


def test_log1p_op():
def module_log1p(a):
return jax.lax.log1p(a)

verify_module(module_log1p, [(3, 3)], required_atol=2e-2)
verify_module(module_log1p, [(3, 3, 3)], required_atol=2e-2)


def test_sign_op():
def module_sign(a):
return jax.lax.sign(a)

verify_module(module_sign, [(3, 3)])
verify_module(module_sign, [(3, 3, 3)])


def test_sqrt_op():
def module_sqrt(a):
return jnp.sqrt(a)
Expand Down Expand Up @@ -230,7 +257,6 @@ def module_transpose(a):
@pytest.mark.parametrize(
"begin, end, dim", [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases]
)

@pytest.mark.skip("Requires tt-metal uplift.")
def test_slice(begin, end, dim):
def module_slice(a):
Expand Down

0 comments on commit 5476109

Please sign in to comment.