Skip to content

Commit ad701f6

Browse files
committed
[ROCm] skip failing profiling tests
1 parent 5338c4c commit ad701f6

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/profiler_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def testCantStopServerBeforeStartingServer(self):
8383
jax.profiler.stop_server()
8484

8585
def testProgrammaticProfiling(self):
86+
if jtu.is_device_rocm:
87+
raise unittest.SkipTest("[ROCm] test runs infinitely in ci.")
8688
with tempfile.TemporaryDirectory() as tmpdir:
8789
try:
8890
jax.profiler.start_trace(tmpdir)
@@ -104,6 +106,8 @@ def testProgrammaticProfiling(self):
104106
self.assertIn(b"pxla.py", proto)
105107

106108
def testProfilerGetFDOProfile(self):
109+
if jtu.is_device_rocm:
110+
raise unittest.SkipTest("[ROCm] test runs infinitely in ci.")
107111
# Tests stop_and_get_fod_profile could run.
108112
try:
109113
jax.profiler.start_trace("test")
@@ -116,6 +120,8 @@ def testProfilerGetFDOProfile(self):
116120
self.assertIn(b"copy", fdo_profile)
117121

118122
def testProgrammaticProfilingErrors(self):
123+
if jtu.is_device_rocm:
124+
raise unittest.SkipTest("[ROCm] test runs infinitely in ci.")
119125
with self.assertRaisesRegex(RuntimeError, "No profile started"):
120126
jax.profiler.stop_trace()
121127

@@ -131,6 +137,8 @@ def testProgrammaticProfilingErrors(self):
131137
jax.profiler.stop_trace()
132138

133139
def testProgrammaticProfilingContextManager(self):
140+
if jtu.is_device_rocm:
141+
raise unittest.SkipTest("[ROCm] test runs infinitely in ci.")
134142
with tempfile.TemporaryDirectory() as tmpdir:
135143
with jax.profiler.trace(tmpdir):
136144
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(

0 commit comments

Comments
 (0)