diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 8b6370a5cb..ead45cb8f4 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -26,6 +26,16 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) +# source: https://stackoverflow.com/a/22638709 +@pytest.fixture(autouse=True) +def run_around_tests(): + # 1. before test - set up (currently do nothing) + # 2. run test + yield + # 3. after test - teardown + torch._dynamo.reset() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("bias", [True, False])