Skip to content

Commit 52e5c3b

Browse files
Addressed comments
1 parent bdb8276 commit 52e5c3b

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

tests/infrastructure.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,21 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from contextlib import contextmanager
56
import jax
67
import jax.numpy as jnp
78

89

10+
@contextmanager
11+
def run_on_cpu():
12+
devices = jax.local_devices(backend="cpu")
13+
assert len(devices) > 0
14+
cpu = devices[0]
15+
16+
with jax.default_device(cpu):
17+
yield
18+
19+
920
def random_input_tensor(shape, key=42, on_device=False, dtype=jnp.float32):
1021
device_cpu = jax.devices("cpu")[0]
1122
with jax.default_device(device_cpu):
@@ -31,7 +42,7 @@ def compare_tensor_to_golden(
3142
if tensor.ndim == 0:
3243
tensor = tensor.reshape((1,))
3344
if golden.ndim == 0:
34-
with jax.default_device(jax.local_devices(backend="cpu")[0]):
45+
with run_on_cpu():
3546
golden = golden.reshape((1,))
3647

3748
if tensor.device != golden.device:
@@ -73,6 +84,6 @@ def verify_module(
7384
tt_inputs = [jax.device_put(cpu_input, tt_device) for cpu_input in cpu_inputs]
7485
graph = jax.jit(module)
7586
res = graph(*tt_inputs)
76-
with jax.default_device(jax.local_devices(backend="cpu")[0]):
87+
with run_on_cpu():
7788
res_cpu = graph(*cpu_inputs)
7889
compare_tensor_to_golden(res, res_cpu, required_pcc, required_atol)

0 commit comments

Comments
 (0)