Skip to content

Commit 0ab9785

Browse files
committed
[ROCm] Add hip specific checks in threefry test
1 parent ad701f6 commit 0ab9785

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/random_test.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,16 @@ def test_threefry_gpu_kernel_lowering(self):
390390
f = lambda key: jax.random.uniform(key, (1,))
391391
with jax._src.config.threefry_gpu_kernel_lowering(False):
392392
hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
393-
self.assertNotIn("cu_threefry2x32", hlo_text)
393+
if jtu.is_device_rocm():
394+
self.assertNotIn("hip_threefry2x32", hlo_text)
395+
else:
396+
self.assertNotIn("cu_threefry2x32", hlo_text)
394397
with jax._src.config.threefry_gpu_kernel_lowering(True):
395398
hlo_text = jax.jit(f).lower(jax.random.key(17)).as_text()
396-
self.assertIn("cu_threefry2x32", hlo_text)
399+
if jtu.is_device_rocm():
400+
self.assertIn("hip_threefry2x32", hlo_text)
401+
else:
402+
self.assertIn("cu_threefry2x32", hlo_text)
397403

398404
@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
399405
def test_random_seed_offset(self, make_key):

0 commit comments

Comments
 (0)