Skip to content

Commit 42b8273

Browse files
Added an additional issue
1 parent 52e5c3b commit 42b8273

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

tests/TTIR/test_basic_ops.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,13 @@ def module_transpose(a):
225225

226226

227227
@pytest.mark.skip(
228-
"Scalar ops currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
228+
"Scalars currently not working due to issue https://github.com/tenstorrent/tt-xla/issues/73"
229229
)
230-
def test_shape_scalar():
231-
def module_shape_scalar(a):
230+
def test_scalar_type():
231+
def module_scalar_type(a):
232232
return a.shape[0]
233233

234-
verify_module(module_shape_scalar, [(3, 3)])
234+
verify_module(module_scalar_type, [(3, 3)])
235235

236236

237237
# Transpose op failing for higher ranks/dimensions.

tests/infrastructure.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def run_on_cpu():
1717
yield
1818

1919

20+
# TODO(issue #80): Make this creation more explicit regarding the originating devices.
2021
def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32):
2122
device_cpu = jax.devices("cpu")[0]
2223
with jax.default_device(device_cpu):
@@ -39,6 +40,7 @@ def compare_tensor_to_golden(
3940
):
4041
ret = True
4142

43+
# TODO (issue #81): Remove these reshapes once the PJRT can handle scalars.
4244
if tensor.ndim == 0:
4345
tensor = tensor.reshape((1,))
4446
if golden.ndim == 0:

0 commit comments

Comments
 (0)