2 files changed: +10 −0

File 1 of 2:

@@ -18,6 +18,7 @@
 from executorch.examples.models.llama.llama_transformer import Transformer
 
 from executorch.examples.models.llama.model_args import ModelArgs
+from torchao.utils import TorchAOBaseTensor
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -257,6 +258,9 @@ def __init__(self, **kwargs):
                 strict=False,
                 assign=True,
             )  # self.model_ = Transformer(gptconf)
+            for param in self.model_.parameters():
+                if isinstance(param, TorchAOBaseTensor):
+                    param.requires_grad = False
         else:
             print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
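The model-loading change matters because a torchao-quantized checkpoint loaded with assign=True hands the module tensor-subclass parameters, and those should not carry autograd state. Below is a minimal sketch of the same pattern; the toy model and the quantize_/int8_weight_only calls are assumptions about the torchao quantization API (names vary across versions), while TorchAOBaseTensor is the class the diff actually checks against.

import torch
from torchao.quantization import int8_weight_only, quantize_  # assumed torchao API
from torchao.utils import TorchAOBaseTensor

# Toy model whose weights become torchao tensor subclasses after quantization.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
quantize_(model, int8_weight_only())

# Same pattern as the diff: any parameter that is a torchao tensor subclass
# is marked non-trainable so autograd never touches the quantized weights.
for param in model.parameters():
    if isinstance(param, TorchAOBaseTensor):
        param.requires_grad = False

frozen = sum(1 for p in model.parameters() if not p.requires_grad)
print(f"{frozen} tensor-subclass parameters frozen")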
File 2 of 2:

@@ -41,6 +41,7 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
+from torchao.utils import unwrap_tensor_subclass
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -199,6 +200,11 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         return edge_config
 
     def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
+        if module is not None:
+            unwrap_tensor_subclass(module)
+        else:
+            unwrap_tensor_subclass(self.model)
+
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
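The export-side change flattens torchao tensor subclasses before tracing, since torch.export works on plain tensors and the wrapper subclasses need to be unwrapped first. The sketch below mirrors that flow under the same assumptions as above (toy module, assumed quantization API); unwrap_tensor_subclass and export_for_training are the names the diff itself imports.

import torch
from torch.export import export_for_training
from torchao.quantization import int8_weight_only, quantize_  # assumed torchao API
from torchao.utils import unwrap_tensor_subclass

# Toy module with torchao-quantized weights.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4)).eval()
quantize_(model, int8_weight_only())

# Flatten tensor-subclass parameters into plain tensors (in place), mirroring
# what _export now does before building the ExportedProgram.
unwrap_tensor_subclass(model)

example_inputs = (torch.randn(2, 16),)
exported = export_for_training(model, example_inputs)
print(type(exported).__name__)  # ExportedProgram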