@@ -407,7 +407,7 @@ def bilinear_interpolate(data, y, x, snap_border=False):
407407
408408
409409class TestRoIAlign (RoIOpTester ):
410- mps_backward_atol = 6e-2
410+ mps_backward_atol = 1e-1
411411
412412 def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , aligned = False , ** kwargs ):
413413 return ops .RoIAlign (
@@ -1244,6 +1244,20 @@ def xfail_if_mps(x):
12441244 return x
12451245
12461246
1247+ def skip_if_mps (x ):
1248+ mps_skip_param = pytest .param ("mps" , marks = (pytest .mark .needs_mps , pytest .mark .skip (reason = "Flaky on MPS" )))
1249+ new_pytestmark = []
1250+ for mark in x .pytestmark :
1251+ if isinstance (mark , pytest .Mark ) and mark .name == "parametrize" :
1252+ if mark .args [0 ] == "device" :
1253+ params = cpu_and_cuda () + (mps_skip_param ,)
1254+ new_pytestmark .append (pytest .mark .parametrize ("device" , params ))
1255+ continue
1256+ new_pytestmark .append (mark )
1257+ x .__dict__ ["pytestmark" ] = new_pytestmark
1258+ return x
1259+
1260+
12471261optests .generate_opcheck_tests (
12481262 testcase = TestDeformConv ,
12491263 namespaces = ["torchvision" ],
@@ -1252,6 +1266,7 @@ def xfail_if_mps(x):
12521266 additional_decorators = {
12531267 "test_aot_dispatch_dynamic__test_forward" : [xfail_if_mps ],
12541268 "test_autograd_registration__test_forward" : [xfail_if_mps ],
1269+ "test_faketensor__test_forward" : [skip_if_mps ],
12551270 },
12561271 test_utils = OPTESTS ,
12571272)
0 commit comments