Skip to content

Commit

Permalink
Fixed op tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Mar 10, 2025
1 parent cd3e64e commit ec123bf
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 170 deletions.
1 change: 0 additions & 1 deletion tests/jax/graphs/test_example_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def example_graph(x: jax.Array, y: jax.Array) -> jax.Array:


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.parametrize(
["x_shape", "y_shape"],
[
Expand Down
3 changes: 1 addition & 2 deletions tests/jax/multichip/manual/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import jax.numpy as jnp
import pytest
from infra import make_partition_spec, run_multichip_test_with_random_inputs
from utils import TestCategory, compile_fail
from utils import compile_fail


@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
@pytest.mark.parametrize(
("x_shape", "mesh_shape", "axis_names"), [((8192, 784), (2,), ("batch",))]
)
Expand Down
3 changes: 1 addition & 2 deletions tests/jax/multichip/manual/data_paralelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import jax.numpy as jnp
import pytest
from infra import make_partition_spec, run_multichip_test_with_random_inputs
from utils import TestCategory, compile_fail
from utils import compile_fail


@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
@pytest.mark.parametrize(
[
"batch_shape",
Expand Down
3 changes: 1 addition & 2 deletions tests/jax/multichip/manual/psum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import jax.numpy as jnp
import pytest
from infra import make_partition_spec, run_multichip_test_with_random_inputs
from utils import TestCategory, compile_fail
from utils import compile_fail


@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
@pytest.mark.parametrize(
["batch_shape", "W1_shape", "B1_shape", "mesh_shape", "axis_names"],
[
Expand Down
3 changes: 1 addition & 2 deletions tests/jax/multichip/manual/psum_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import jax.numpy as jnp
import pytest
from infra import make_partition_spec, run_multichip_test_with_random_inputs
from utils import TestCategory, compile_fail
from utils import compile_fail


@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
@pytest.mark.parametrize(
["batch_shape", "W1_shape", "B1_shape", "mesh_shape", "axis_names"],
[
Expand Down
3 changes: 1 addition & 2 deletions tests/jax/multichip/manual/unary_eltwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import jax.numpy as jnp
import pytest
from infra import make_partition_spec, run_multichip_test_with_random_inputs
from utils import TestCategory, compile_fail
from utils import compile_fail


@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
@pytest.mark.parametrize(
("x_shape", "mesh_shape", "axis_names"), [((256, 256), (1, 2), ("x", "y"))]
)
Expand Down
33 changes: 31 additions & 2 deletions tests/jax/ops/test_and.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# TODO add tests for `stablehlo.and`.
import jax
import jax.numpy as jnp
import pytest
from infra import random_tensor, run_op_test
from utils import TestCategory, convert_output_to_bfloat16


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(
test_category=TestCategory.OP_TEST.value,
jax_op_name="jax.numpy.logical_and",
shlo_op_name="stablehlo.and{LOGICAL}",
)
@pytest.mark.parametrize(
["shape"],
[
[(32, 32)],
[(64, 64)],
],
ids=lambda val: f"{val}",
)
def test_logical_and(shape: tuple):
@convert_output_to_bfloat16
def logical_and(a: jax.Array, b: jax.Array) -> jax.Array:
return jnp.logical_and(a, b)

lhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=3)
rhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=6)
run_op_test(logical_and, [lhs, rhs])
36 changes: 0 additions & 36 deletions tests/jax/ops/test_logical_and.py

This file was deleted.

35 changes: 0 additions & 35 deletions tests/jax/ops/test_logical_not.py

This file was deleted.

36 changes: 0 additions & 36 deletions tests/jax/ops/test_logical_or.py

This file was deleted.

36 changes: 0 additions & 36 deletions tests/jax/ops/test_logical_xor.py

This file was deleted.

32 changes: 30 additions & 2 deletions tests/jax/ops/test_not.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,33 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# TODO add tests for `stablehlo.not`.
import jax
import jax.numpy as jnp
import pytest
from infra import random_tensor, run_op_test
from utils import TestCategory, convert_output_to_bfloat16


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(
test_category=TestCategory.OP_TEST.value,
jax_op_name="jax.numpy.logical_not",
shlo_op_name="stablehlo.not{LOGICAL}",
)
@pytest.mark.parametrize(
["shape"],
[
[(32, 32)],
[(64, 64)],
],
ids=lambda val: f"{val}",
)
def test_logical_not(shape: tuple):
@convert_output_to_bfloat16
def logical_not(a: jax.Array) -> jax.Array:
return jnp.logical_not(a)

input = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=3)
run_op_test(logical_not, [input])
33 changes: 31 additions & 2 deletions tests/jax/ops/test_or.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# TODO add tests for `stablehlo.or`.
import jax
import jax.numpy as jnp
import pytest
from infra import random_tensor, run_op_test
from utils import TestCategory, convert_output_to_bfloat16


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(
test_category=TestCategory.OP_TEST.value,
jax_op_name="jax.numpy.logical_or",
shlo_op_name="stablehlo.or{LOGICAL}",
)
@pytest.mark.parametrize(
["shape"],
[
[(32, 32)],
[(64, 64)],
],
ids=lambda val: f"{val}",
)
def test_logical_or(shape: tuple):
@convert_output_to_bfloat16
def logical_or(a: jax.Array, b: jax.Array) -> jax.Array:
return jnp.logical_or(a, b)

lhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=3)
rhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=6)
run_op_test(logical_or, [lhs, rhs])
14 changes: 7 additions & 7 deletions tests/jax/ops/test_tanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
import jax.numpy as jnp
import pytest
from infra import run_op_test_with_random_inputs
from utils import record_unary_op_test_properties
from utils import TestCategory


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(
test_category=TestCategory.OP_TEST.value,
jax_op_name="jax.numpy.subtract",
shlo_op_name="stablehlo.subtract",
)
@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}")
def test_tanh(x_shape: tuple, record_tt_xla_property: Callable):
def test_tanh(x_shape: tuple):
def tanh(x: jax.Array) -> jax.Array:
return jnp.tanh(x)

record_unary_op_test_properties(
record_tt_xla_property,
"jax.numpy.tanh",
"stablehlo.tanh",
)
run_op_test_with_random_inputs(tanh, [x_shape])
33 changes: 31 additions & 2 deletions tests/jax/ops/test_xor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# TODO add tests for `stablehlo.xor`.
import jax
import jax.numpy as jnp
import pytest
from infra import random_tensor, run_op_test
from utils import TestCategory, convert_output_to_bfloat16


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(
test_category=TestCategory.OP_TEST.value,
jax_op_name="jax.numpy.logical_xor",
shlo_op_name="stablehlo.xor{LOGICAL}",
)
@pytest.mark.parametrize(
["shape"],
[
[(32, 32)],
[(64, 64)],
],
ids=lambda val: f"{val}",
)
def test_logical_xor(shape: tuple):
@convert_output_to_bfloat16
def logical_xor(a: jax.Array, b: jax.Array) -> jax.Array:
return jnp.logical_xor(a, b)

lhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=3)
rhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=6)
run_op_test(logical_xor, [lhs, rhs])
1 change: 0 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class TestCategory(Enum):
OP_TEST = "op_test"
GRAPH_TEST = "graph_test"
MODEL_TEST = "model_test"
MULTICHIP_TEST = "multichip_test"
OTHER = "other"


Expand Down

0 comments on commit ec123bf

Please sign in to comment.