File tree 1 file changed +8
-2
lines changed
1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -390,10 +390,16 @@ def test_threefry_gpu_kernel_lowering(self):
390
390
f = lambda key : jax .random .uniform (key , (1 ,))
391
391
with jax ._src .config .threefry_gpu_kernel_lowering (False ):
392
392
hlo_text = jax .jit (f ).lower (jax .random .key (17 )).as_text ()
393
- self .assertNotIn ("cu_threefry2x32" , hlo_text )
393
+ if jtu .is_device_rocm ():
394
+ self .assertNotIn ("hip_threefry2x32" , hlo_text )
395
+ else :
396
+ self .assertNotIn ("cu_threefry2x32" , hlo_text )
394
397
with jax ._src .config .threefry_gpu_kernel_lowering (True ):
395
398
hlo_text = jax .jit (f ).lower (jax .random .key (17 )).as_text ()
396
- self .assertIn ("cu_threefry2x32" , hlo_text )
399
+ if jtu .is_device_rocm ():
400
+ self .assertIn ("hip_threefry2x32" , hlo_text )
401
+ else :
402
+ self .assertIn ("cu_threefry2x32" , hlo_text )
397
403
398
404
@parameterized .parameters ([{'make_key' : ctor } for ctor in KEY_CTORS ])
399
405
def test_random_seed_offset (self , make_key ):
You can’t perform that action at this time.
0 commit comments