Skip to content

Commit

Permalink
Merge pull request #75 from tenstorrent/ajakovljevic/scalar_test_fix
Browse files Browse the repository at this point in the history
Fixes for test infrastructure to support scalars
  • Loading branch information
ajakovljevicTT authored Nov 25, 2024
2 parents 5aadbf1 + 42b8273 commit 84c08b0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
10 changes: 10 additions & 0 deletions tests/TTIR/test_basic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ def module_transpose(a):
verify_module(module_transpose, [(3, 3)])


@pytest.mark.skip(
"Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
)
def test_scalar_type():
def module_scalar_type(a):
return a.shape[0]

verify_module(module_scalar_type, [(3, 3)])


# Transpose op failing for higher ranks/dimensions.
@pytest.mark.skip("Transpose op failing for higher ranks/dimensions.")
def test_transpose_op_3d():
Expand Down
24 changes: 22 additions & 2 deletions tests/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@
#
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
import jax
import jax.numpy as jnp


@contextmanager
def run_on_cpu():
devices = jax.local_devices(backend="cpu")
assert len(devices) > 0
cpu = devices[0]

with jax.default_device(cpu):
yield


# TODO(issue #80): Make this creation more explicit regarding the originating devices.
def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32):
device_cpu = jax.devices("cpu")[0]
with jax.default_device(device_cpu):
Expand All @@ -27,6 +39,14 @@ def compare_tensor_to_golden(
tensor, golden, required_pcc=0.99, required_atol=1e-2, assert_on_error=True
):
ret = True

# TODO (issue #81): Remove these reshapes once the PJRT can handle scalars.
if tensor.ndim == 0:
tensor = tensor.reshape((1,))
if golden.ndim == 0:
with run_on_cpu():
golden = golden.reshape((1,))

if tensor.device != golden.device:
tensor = jax.device_put(tensor, golden.device)

Expand Down Expand Up @@ -66,6 +86,6 @@ def verify_module(
tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
graph = jax.jit(module)
res = graph(*tt_inputs)
res_cpu = graph(*cpu_inputs)

with run_on_cpu():
res_cpu = graph(*cpu_inputs)
compare_tensor_to_golden(res, res_cpu, required_pcc, required_atol)

0 comments on commit 84c08b0

Please sign in to comment.