Skip to content

Commit ec123bf

Browse files
committed
Fixed op tests
1 parent cd3e64e commit ec123bf

16 files changed

+135
-170
lines changed

tests/jax/graphs/test_example_graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def example_graph(x: jax.Array, y: jax.Array) -> jax.Array:
1616

1717

1818
@pytest.mark.push
19-
@pytest.mark.nightly
2019
@pytest.mark.parametrize(
2120
["x_shape", "y_shape"],
2221
[

tests/jax/multichip/manual/all_gather.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import jax.numpy as jnp
77
import pytest
88
from infra import make_partition_spec, run_multichip_test_with_random_inputs
9-
from utils import TestCategory, compile_fail
9+
from utils import compile_fail
1010

1111

12-
@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
1312
@pytest.mark.parametrize(
1413
("x_shape", "mesh_shape", "axis_names"), [((8192, 784), (2,), ("batch",))]
1514
)

tests/jax/multichip/manual/data_paralelism.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import jax.numpy as jnp
77
import pytest
88
from infra import make_partition_spec, run_multichip_test_with_random_inputs
9-
from utils import TestCategory, compile_fail
9+
from utils import compile_fail
1010

1111

12-
@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
1312
@pytest.mark.parametrize(
1413
[
1514
"batch_shape",

tests/jax/multichip/manual/psum.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import jax.numpy as jnp
77
import pytest
88
from infra import make_partition_spec, run_multichip_test_with_random_inputs
9-
from utils import TestCategory, compile_fail
9+
from utils import compile_fail
1010

1111

12-
@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
1312
@pytest.mark.parametrize(
1413
["batch_shape", "W1_shape", "B1_shape", "mesh_shape", "axis_names"],
1514
[

tests/jax/multichip/manual/psum_scatter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import jax.numpy as jnp
77
import pytest
88
from infra import make_partition_spec, run_multichip_test_with_random_inputs
9-
from utils import TestCategory, compile_fail
9+
from utils import compile_fail
1010

1111

12-
@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
1312
@pytest.mark.parametrize(
1413
["batch_shape", "W1_shape", "B1_shape", "mesh_shape", "axis_names"],
1514
[

tests/jax/multichip/manual/unary_eltwise.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import jax.numpy as jnp
77
import pytest
88
from infra import make_partition_spec, run_multichip_test_with_random_inputs
9-
from utils import TestCategory, compile_fail
9+
from utils import compile_fail
1010

1111

12-
@pytest.mark.record_test_properties(test_category=TestCategory.MULTICHIP_TEST.value)
1312
@pytest.mark.parametrize(
1413
("x_shape", "mesh_shape", "axis_names"), [((256, 256), (1, 2), ("x", "y"))]
1514
)

tests/jax/ops/test_and.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
1-
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
1+
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
# TODO add tests for `stablehlo.and`.
5+
import jax
6+
import jax.numpy as jnp
7+
import pytest
8+
from infra import random_tensor, run_op_test
9+
from utils import TestCategory, convert_output_to_bfloat16
10+
11+
12+
@pytest.mark.push
13+
@pytest.mark.nightly
14+
@pytest.mark.record_test_properties(
15+
test_category=TestCategory.OP_TEST.value,
16+
jax_op_name="jax.numpy.logical_and",
17+
shlo_op_name="stablehlo.and{LOGICAL}",
18+
)
19+
@pytest.mark.parametrize(
20+
["shape"],
21+
[
22+
[(32, 32)],
23+
[(64, 64)],
24+
],
25+
ids=lambda val: f"{val}",
26+
)
27+
def test_logical_and(shape: tuple):
28+
@convert_output_to_bfloat16
29+
def logical_and(a: jax.Array, b: jax.Array) -> jax.Array:
30+
return jnp.logical_and(a, b)
31+
32+
lhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=3)
33+
rhs = random_tensor(shape, jnp.int32, minval=0, maxval=2, random_seed=6)
34+
run_op_test(logical_and, [lhs, rhs])

tests/jax/ops/test_logical_and.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/jax/ops/test_logical_not.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

tests/jax/ops/test_logical_or.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)