
Commit 5e4d382

jerryzh168 authored and amdfaa committed
Skip calling unwrap_tensor_subclass for torch 2.7+ (#1531)
* Skip calling `unwrap_tensor_subclass` for torch 2.7+

  Summary: att

  Test Plan: `python test/integration/test_integration.py -k TestExport`. The AOTI test also works, but there are some AOT compile issues, reproducible by `python test/integration/test_integration.py -k TestAOTI.test_aoti_06`.

  Reviewers: Subscribers: Tasks: Tags:

* fix
1 parent eb180cd commit 5e4d382

File tree: 3 files changed, +16 −11 lines

test/integration/test_integration.py (+10 −4)

```diff
@@ -76,6 +76,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    TORCH_VERSION_AT_LEAST_2_7,
     benchmark_model,
     is_fbcode,
     is_sm_at_least_90,
@@ -1749,7 +1750,10 @@ def test_autoquant_min_sqnr(self, device, dtype):
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-@unittest.skip("AOTI tests are failing right now")
+@unittest.skip(
+    "AOTI tests are failing right now, repro by commenting out the skip and run:"
+    "python test/integration/test_integration.py -k TestAOTI.test_aoti_06"
+)
 class TestAOTI(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
@@ -1792,7 +1796,8 @@ def forward(self, x):
         model(x)
 
         api(model)
-        unwrap_tensor_subclass(model)
+        if not TORCH_VERSION_AT_LEAST_2_7:
+            unwrap_tensor_subclass(model)
 
         # running model
         model(x)
@@ -1802,7 +1807,7 @@ def forward(self, x):
 
         example_inputs = (x,)
         torch._inductor.aoti_compile_and_package(
-            torch.export.export(model, example_inputs, strict=True), example_inputs
+            torch.export.export(model, example_inputs, strict=True)
         )
 
 
@@ -1851,7 +1856,8 @@ def forward(self, x):
         model(x)
 
         api(model)
-        unwrap_tensor_subclass(model)
+        if not TORCH_VERSION_AT_LEAST_2_7:
+            unwrap_tensor_subclass(model)
 
         # running model
         ref = model(x)
```

torchao/quantization/README.md (+4 −7)

````diff
@@ -296,9 +296,9 @@ for module, name in model.named_modules():
     module.weight = nn.Parameter(to_linear_activation_quantized(module.weight, input_quant_func))
 ```
 
-#### Workaround with `unwrap_tensor_subclass` for `export`, `AOTI` and `torch.compile` (pytorch 2.4 and before only)
-The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support
-`torch.export.export` and `torch.aot_compile` with the following workaround:
+#### Workaround with `unwrap_tensor_subclass` for `export`, `AOTI` and `torch.compile`
+
+If you are using pytorch 2.6 or before, you need to call `unwrap_tensor_subclass` before `torch.export.export` and `aot_compile`:
 ```
 from torchao.utils import unwrap_tensor_subclass
 m_unwrapped = unwrap_tensor_subclass(m)
@@ -311,10 +311,7 @@ m = torch.export.export(m_unwrapped, example_inputs).module()
 torch._export.aot_compile(m_unwrapped, example_inputs)
 ```
 
-For `torch.compile`, if you are using pytorch nightly or pytorch 2.5+, you won't need to use `unwrap_tensor_subclass` in order to be compatible with `torch.compile`,
-but if you use 2.4 or before, you'll need to use `unwrap_tensor_subclass` as well to be able to run `torch.compile` on the quantized model.
-
-Note that the workaround will not be needed after https://github.com/pytorch/pytorch/issues/129682 is fixed.
+If you are using pytorch 2.4 or before, you'll also need `unwrap_tensor_subclass` before calling `torch.compile` as well.
 
 Note that the workaround is also required for `torch.compile` with `freezing` (`torch._inductor.config.freezing=True`) until https://github.com/pytorch/pytorch/pull/136265 is fixed.
````
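The README guidance above reduces to a single version predicate: call `unwrap_tensor_subclass` only on torch versions before 2.7. A minimal, torch-free sketch of that decision (`needs_unwrap` is a hypothetical helper for illustration, not a torchao API):

```python
def needs_unwrap(torch_version: str) -> bool:
    """Return True when unwrap_tensor_subclass should be called before export.

    Hypothetical helper: keeps only the leading numeric components of the
    version string (dropping any "+git..." local part) and checks whether
    the version is below 2.7.
    """
    parts = []
    for piece in torch_version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts) < (2, 7)

print(needs_unwrap("2.6.0"))  # True: call unwrap_tensor_subclass first
print(needs_unwrap("2.7.1"))  # False: export handles the subclass directly
```

Note this sketch ignores pre-release ordering (e.g. `2.7.0a0` is treated as 2.7), which is sufficient for a coarse feature gate but not a general version comparison.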

torchao/utils.py (+2 −0)

```diff
@@ -27,6 +27,7 @@
     "TORCH_VERSION_AT_LEAST_2_4",
     "TORCH_VERSION_AT_LEAST_2_5",
     "TORCH_VERSION_AT_LEAST_2_6",
+    "TORCH_VERSION_AT_LEAST_2_7",
     # Needs to be deprecated in the future
     "TORCH_VERSION_AFTER_2_2",
     "TORCH_VERSION_AFTER_2_3",
@@ -367,6 +368,7 @@ def torch_version_at_least(min_version):
     return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0
 
 
+TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
 TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
 TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
 TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
```
