From 2e4a5e7cfd68845f42a72122891740e9175225ab Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 09:33:05 +0000 Subject: [PATCH 1/7] Added quick scalar support and test --- tests/TTIR/test_basic_ops.py | 8 ++++++++ tests/infrastructure.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index e840571c..611ba7c0 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -7,6 +7,8 @@ import jax import jax.numpy as jnp import numpy +from jax import export + from infrastructure import verify_module @@ -223,6 +225,12 @@ def module_transpose(a): verify_module(module_transpose, [(3, 3)]) +def test_shape_scalar(): + def module_shape_scalar(a): + return a.shape[0] + + verify_module(module_shape_scalar, [(3, 3)]) + # Transpose op failing for higher ranks/dimensions. @pytest.mark.skip("Transpose op failing for higher ranks/dimensions.") diff --git a/tests/infrastructure.py b/tests/infrastructure.py index dd0ce931..1d7c6960 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -42,7 +42,7 @@ def compare_tensor_to_golden( if assert_on_error: assert ret, f"PCC is {pcc} which is less than {required_pcc}" - atol = jnp.max(jnp.abs(tensor - golden)) + atol = jnp.abs(tensor - golden) if len(tensor.shape) == 0 else jnp.max(jnp.abs(tensor - golden)) ret = ret and atol <= required_atol if assert_on_error: assert ret, f"ATOL is {atol} which is greater than {required_atol}" From bbefab7c7c84a69cfff9e39b96a79d82aed9bfce Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 09:34:38 +0000 Subject: [PATCH 2/7] Added skip test --- tests/TTIR/test_basic_ops.py | 4 ++++ tests/infrastructure.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index 611ba7c0..2551930f 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -225,6 +225,10 @@ def module_transpose(a): verify_module(module_transpose, [(3, 3)]) + +@pytest.mark.skip( + "Scalar ops currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73" +) def test_shape_scalar(): def module_shape_scalar(a): return a.shape[0] diff --git a/tests/infrastructure.py b/tests/infrastructure.py index 1d7c6960..b19d3934 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -42,7 +42,11 @@ def compare_tensor_to_golden( if assert_on_error: assert ret, f"PCC is {pcc} which is less than {required_pcc}" - atol = jnp.abs(tensor - golden) if len(tensor.shape) == 0 else jnp.max(jnp.abs(tensor - golden)) + atol = ( + jnp.abs(tensor - golden) + if len(tensor.shape) == 0 + else jnp.max(jnp.abs(tensor - golden)) + ) ret = ret and atol <= required_atol if assert_on_error: assert ret, f"ATOL is {atol} which is greater than {required_atol}" From b579ccd875b2af71bee2a9a271affd84afe1a2df Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 09:38:25 +0000 Subject: [PATCH 3/7] Deleted unwated libraries --- tests/TTIR/test_basic_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index 2551930f..97885cea 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -7,8 +7,6 @@ import jax import jax.numpy as jnp import numpy -from jax import export - from infrastructure import verify_module From 624415407e7d05e1bf0bca51b93d01762632862d Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 17:22:14 +0000 Subject: [PATCH 4/7] Add force cpu --- tests/infrastructure.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/infrastructure.py b/tests/infrastructure.py index b19d3934..f8407a9d 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -27,6 +27,13 @@ def compare_tensor_to_golden( tensor, golden, required_pcc=0.99, required_atol=1e-2, assert_on_error=True ): ret = True + + if tensor.ndim == 0: + tensor = tensor.reshape((1,)) + if golden.ndim == 0: + with jax.default_device(jax.local_devices(backend="cpu")[0]): + golden = golden.reshape((1,)) + if tensor.device != golden.device: tensor = jax.device_put(tensor, golden.device) @@ -41,12 +48,7 @@ def compare_tensor_to_golden( ) if assert_on_error: assert ret, f"PCC is {pcc} which is less than {required_pcc}" - - atol = ( - jnp.abs(tensor - golden) - if len(tensor.shape) == 0 - else jnp.max(jnp.abs(tensor - golden)) - ) + atol = jnp.max(jnp.abs(tensor - golden)) ret = ret and atol <= required_atol if assert_on_error: assert ret, f"ATOL is {atol} which is greater than {required_atol}" @@ -70,6 +72,6 @@ def verify_module( tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs] graph = jax.jit(module) res = graph(*tt_inputs) - res_cpu = graph(*cpu_inputs) - + with jax.default_device(jax.local_devices(backend="cpu")[0]): + res_cpu = graph(*cpu_inputs) compare_tensor_to_golden(res, res_cpu, required_pcc, required_atol) From bdb8276b22ceb954dc1ccc27c122d4680c35d770 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Tue, 19 Nov 2024 17:23:47 +0000 Subject: [PATCH 5/7] Format change --- tests/infrastructure.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/infrastructure.py b/tests/infrastructure.py index f8407a9d..9aadad74 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -48,6 +48,7 @@ def compare_tensor_to_golden( ) if assert_on_error: assert ret, f"PCC is {pcc} which is less than {required_pcc}" + atol = jnp.max(jnp.abs(tensor - golden)) ret = ret and atol <= required_atol if assert_on_error: From 52e5c3b43576fa4535bb34c1320ff2718531c138 Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Wed, 20 Nov 2024 14:45:47 +0000 Subject: [PATCH 6/7] Addressed comments --- tests/infrastructure.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/infrastructure.py b/tests/infrastructure.py index 9aadad74..ee654184 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -2,10 +2,21 @@ # # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager import jax import jax.numpy as jnp +@contextmanager +def run_on_cpu(): + devices = jax.local_devices(backend="cpu") + assert len(devices) > 0 + cpu = devices[0] + + with jax.default_device(cpu): + yield + + def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32): device_cpu = jax.devices("cpu")[0] with jax.default_device(device_cpu): @@ -31,7 +42,7 @@ def compare_tensor_to_golden( if tensor.ndim == 0: tensor = tensor.reshape((1,)) if golden.ndim == 0: - with jax.default_device(jax.local_devices(backend="cpu")[0]): + with run_on_cpu(): golden = golden.reshape((1,)) if tensor.device != golden.device: @@ -73,6 +84,6 @@ def verify_module( tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs] graph = jax.jit(module) res = graph(*tt_inputs) - with jax.default_device(jax.local_devices(backend="cpu")[0]): + with run_on_cpu(): res_cpu = graph(*cpu_inputs) compare_tensor_to_golden(res, res_cpu, required_pcc, required_atol) From 42b8273bc6bc25fdc2c20c422d82c804b324318f Mon Sep 17 00:00:00 2001 From: Andrej Jakovljevic Date: Fri, 22 Nov 2024 14:25:22 +0000 Subject: [PATCH 7/7] Added an additional issue --- tests/TTIR/test_basic_ops.py | 8 ++++---- tests/infrastructure.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index 97885cea..e9d9e8df 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -225,13 +225,13 @@ def module_transpose(a): @pytest.mark.skip( - "Scalar ops currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73" + "Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73" ) -def test_shape_scalar(): - def module_shape_scalar(a): +def test_scalar_type(): + def module_scalar_type(a): return a.shape[0] - verify_module(module_shape_scalar, [(3, 3)]) + verify_module(module_scalar_type, [(3, 3)]) # Transpose op failing for higher ranks/dimensions. diff --git a/tests/infrastructure.py b/tests/infrastructure.py index ee654184..42c16c97 100644 --- a/tests/infrastructure.py +++ b/tests/infrastructure.py @@ -17,6 +17,7 @@ def run_on_cpu(): yield +# TODO(issue #80): Make this creation more explicit regarding the originating devices. def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32): device_cpu = jax.devices("cpu")[0] with jax.default_device(device_cpu): @@ -39,6 +40,7 @@ def compare_tensor_to_golden( ): ret = True + # TODO (issue #81): Remove these reshapes once the PJRT can handle scalars. if tensor.ndim == 0: tensor = tensor.reshape((1,)) if golden.ndim == 0: