Commit 15888ca

Support pre-quantization via torchao quantize_ (#10293)
Checkpoints saved with torchao quantized tensor subclasses can now be loaded with this PR.
1 parent 2a2f958 commit 15888ca
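
For context, a pre-quantized checkpoint of this kind could be produced roughly as follows. This is a minimal sketch, not part of the commit: the toy module, the int8_weight_only config, and the output file name are illustrative assumptions.

import torch
from torchao.quantization import int8_weight_only, quantize_

# Illustrative stand-in for the llama Transformer used in the examples.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# quantize_ swaps eligible weights for torchao quantized tensor subclasses in place.
quantize_(model, int8_weight_only())

# The saved state_dict now contains quantized subclass tensors; this commit
# teaches the llama example and the export builder to handle such checkpoints.
torch.save(model.state_dict(), "llama_int8_checkpoint.pth")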

File tree: 2 files changed, +10 −0 lines

examples/models/llama/model.py (+4)
@@ -18,6 +18,7 @@
 from executorch.examples.models.llama.llama_transformer import Transformer
 
 from executorch.examples.models.llama.model_args import ModelArgs
+from torchao.utils import TorchAOBaseTensor
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint

@@ -257,6 +258,9 @@ def __init__(self, **kwargs):
                 strict=False,
                 assign=True,
             )  # self.model_ = Transformer(gptconf)
+            for param in self.model_.parameters():
+                if isinstance(param, TorchAOBaseTensor):
+                    param.requires_grad = False
         else:
             print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
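
The requires_grad reset matters because load_state_dict(..., assign=True) keeps the checkpoint's tensor subclasses as the module's parameters rather than copying values into the existing float parameters. A minimal sketch of that load path, assuming a checkpoint produced with torchao's int8_weight_only config; the toy module is illustrative, not the actual llama Transformer:

import torch
from torchao.quantization import int8_weight_only, quantize_
from torchao.utils import TorchAOBaseTensor

# Produce a quantized state_dict (stand-in for a pre-quantized llama checkpoint).
src = torch.nn.Sequential(torch.nn.Linear(16, 16))
quantize_(src, int8_weight_only())
checkpoint = src.state_dict()

# assign=True keeps the checkpoint's torchao subclass tensors as the parameters.
dst = torch.nn.Sequential(torch.nn.Linear(16, 16))
dst.load_state_dict(checkpoint, strict=False, assign=True)

# Quantized subclass parameters are inference-only, so freeze them,
# mirroring the model.py change above.
for param in dst.parameters():
    if isinstance(param, TorchAOBaseTensor):
        param.requires_grad = False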

extension/llm/export/builder.py (+6)
@@ -41,6 +41,7 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
+from torchao.utils import unwrap_tensor_subclass
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

@@ -199,6 +200,11 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         return edge_config
 
     def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
+        if module is not None:
+            unwrap_tensor_subclass(module)
+        else:
+            unwrap_tensor_subclass(self.model)
+
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
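
Here unwrap_tensor_subclass rewrites the tensor-subclass weights into a representation that torch.export can trace before the export step below. A rough sketch of that step under the same assumptions as above (toy module and int8_weight_only config are illustrative, not the actual model handled by the builder):

import torch
from torchao.quantization import int8_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass

# Illustrative stand-in for the model passed into _export.
m = torch.nn.Sequential(torch.nn.Linear(32, 32))
quantize_(m, int8_weight_only())

# Convert the subclass weights into plain tensors torch.export can trace.
unwrap_tensor_subclass(m)

ep = torch.export.export_for_training(m, (torch.randn(1, 32),))
print(ep)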
