Skip to content

Commit

Permalink
Merge pull request #51 from tenstorrent/kmitrovic/stablehlo_remainder_op
Browse files Browse the repository at this point in the history
Added test for jax.remainder op
  • Loading branch information
kmitrovicTT authored Nov 16, 2024
2 parents ade7d4c + e8db19e commit 56cc716
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,60 @@ def module_slice(a):
shape = [10, 10, 10, 10]
shape[dim] = 128
verify_module(module_slice, [shape])


@pytest.mark.parametrize(
"input_shapes",
[
[(32, 32), (32, 32)],
pytest.param(
[(3, 3), (3, 3)],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/70"
),
),
pytest.param(
[(3, 3, 3), (3, 3, 3)],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/70"
),
),
],
)
def test_remainder_op_lax(input_shapes):
def module_remainder_lax(a, b):
return jax.lax.rem(a, b)

verify_module(module_remainder_lax, input_shapes, required_atol=0.02)


@pytest.mark.parametrize(
"input_shapes",
[
pytest.param(
[(32, 32), (32, 32)],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/71"
),
),
pytest.param(
[(3, 3), (3, 3)],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/70"
),
),
pytest.param(
[(3, 3, 3), (3, 3, 3)],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-xla/issues/70"
),
),
],
)
def test_remainder_op_jnp(input_shapes):
# `jnp.remainder` generates a more complex stablehlo graph than `jax.lax.rem` with
# implicit broadcasts, etc. That's why we have both.
def module_remainder_jnp(a, b):
return jnp.remainder(a, b)

verify_module(module_remainder_jnp, input_shapes, required_atol=0.02)

0 comments on commit 56cc716

Please sign in to comment.