@@ -83,6 +83,8 @@ def testCantStopServerBeforeStartingServer(self):
83
83
jax .profiler .stop_server ()
84
84
85
85
def testProgrammaticProfiling (self ):
86
+ if jtu .is_device_rocm :
87
+ raise unittest .SkipTest ("[ROCm] test runs infinitely in ci." )
86
88
with tempfile .TemporaryDirectory () as tmpdir :
87
89
try :
88
90
jax .profiler .start_trace (tmpdir )
@@ -104,6 +106,8 @@ def testProgrammaticProfiling(self):
104
106
self .assertIn (b"pxla.py" , proto )
105
107
106
108
def testProfilerGetFDOProfile (self ):
109
+ if jtu .is_device_rocm :
110
+ raise unittest .SkipTest ("[ROCm] test runs infinitely in ci." )
107
111
# Tests stop_and_get_fod_profile could run.
108
112
try :
109
113
jax .profiler .start_trace ("test" )
@@ -116,6 +120,8 @@ def testProfilerGetFDOProfile(self):
116
120
self .assertIn (b"copy" , fdo_profile )
117
121
118
122
def testProgrammaticProfilingErrors (self ):
123
+ if jtu .is_device_rocm :
124
+ raise unittest .SkipTest ("[ROCm] test runs infinitely in ci." )
119
125
with self .assertRaisesRegex (RuntimeError , "No profile started" ):
120
126
jax .profiler .stop_trace ()
121
127
@@ -131,6 +137,8 @@ def testProgrammaticProfilingErrors(self):
131
137
jax .profiler .stop_trace ()
132
138
133
139
def testProgrammaticProfilingContextManager (self ):
140
+ if jtu .is_device_rocm :
141
+ raise unittest .SkipTest ("[ROCm] test runs infinitely in ci." )
134
142
with tempfile .TemporaryDirectory () as tmpdir :
135
143
with jax .profiler .trace (tmpdir ):
136
144
jax .pmap (lambda x : jax .lax .psum (x + 1 , 'i' ), axis_name = 'i' )(
0 commit comments