diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index 413b5c0f..58dff4e0 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -224,6 +224,16 @@ def module_transpose(a): verify_module(module_transpose, [(3, 3)]) +@pytest.mark.skip( + "Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73" +) +def test_scalar_type(): + def module_scalar_type(a): + return a.shape[0] + + verify_module(module_scalar_type, [(3, 3)]) + + # Transpose op failing for higher ranks/dimensions. @pytest.mark.skip("Transpose op failing for higher ranks/dimensions.") def test_transpose_op_3d(): diff --git a/tests/infrastructure.py b/tests/infrastructure.py index dd0ce931..42c16c97 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -2,10 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager import jax import jax.numpy as jnp +@contextmanager +def run_on_cpu(): + devices = jax.local_devices(backend="cpu") + assert len(devices) > 0 + cpu = devices[0] + + with jax.default_device(cpu): + yield + + +# TODO(issue #80): Make this creation more explicit regarding the originating devices. def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32): device_cpu = jax.devices("cpu")[0] with jax.default_device(device_cpu): @@ -27,6 +39,14 @@ def compare_tensor_to_golden( tensor, golden, required_pcc=0.99, required_atol=1e-2, assert_on_error=True ): ret = True + + # TODO (issue #81): Remove these reshapes once the PJRT can handle scalars. + if tensor.ndim == 0: + tensor = tensor.reshape((1,)) + if golden.ndim == 0: + with run_on_cpu(): + golden = golden.reshape((1,)) + if tensor.device != golden.device: tensor = jax.device_put(tensor, golden.device) @@ -66,6 +86,6 @@ def verify_module( tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs] graph = jax.jit(module) res = graph(*tt_inputs) - res_cpu = graph(*cpu_inputs) - + with run_on_cpu(): + res_cpu = graph(*cpu_inputs) compare_tensor_to_golden(res, res_cpu, required_pcc, required_atol)