Commit 467568e

Lint and fix qat/spin

1 parent: d7cea25

3 files changed, +30 -28 lines

examples/models/llama/eval_llama_lib.py (-2)

```diff
@@ -21,8 +21,6 @@
 from pytorch_tokenizers import get_tokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
 from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken
-from torch.nn import CrossEntropyLoss
-from tqdm import tqdm
 
 from .evaluate.eager_eval import EagerEvalWrapper
 
```
examples/models/llama/export_llama_lib.py (+1 -1)

```diff
@@ -1251,7 +1251,7 @@ def _load_llama_model(
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
             dtype=torch_dtype,
-            args=config,
+            config=config,
         )
     )
 
```
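The hunk above and the model.py changes below tie together: the call site now passes `config=config`, and the model constructor reads it back with `kwargs.get("config", None)`. A minimal, hypothetical sketch of that handshake (the class and function names here are illustrative, not the real ExecuTorch APIs; only the keyword name and the `kwargs.get()` lookup come from the diff):

```python
# Hypothetical sketch of the keyword handshake introduced by this commit.

class Llama2ModelSketch:
    def __init__(self, **kwargs):
        # After this commit the model looks up "config" instead of "args";
        # a mismatched keyword would silently leave self.config as None.
        self.config = kwargs.get("config", None)


def load_llama_model_sketch(config):
    # Call site: the keyword must match what __init__ looks up.
    return Llama2ModelSketch(config=config)
```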

examples/models/llama/model.py (+29 -25)

```diff
@@ -53,7 +53,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.config = kwargs.get("config", None)
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -156,10 +156,10 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            assert self.config.model.name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if self.config and self.config.model.name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32
 
         if kwargs.get("verbose", False):
```
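The rope hunk assumes the config exposes the model name at `config.model.name`. A tiny sketch of the guarded access pattern it uses, with `SimpleNamespace` standing in for the real config type (an assumption for illustration only):

```python
from types import SimpleNamespace

# Stand-in config object; the real type is whatever the export flow passes in.
config = SimpleNamespace(model=SimpleNamespace(name="llama3_2"))

# Guarded access: only dereference nested fields when a config is present,
# mirroring the `self.config and self.config.model.name not in [...]` check.
if config and config.model.name not in ["llama3", "llama3_1"]:
    rope_scale_factor = 32  # larger scale factor for newer models
```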
examples/models/llama/model.py (continued):

```diff
@@ -194,7 +194,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.config and self.config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -203,19 +203,19 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.config and self.config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.config and self.config.quantization.use_lora:
+                assert model_args.lora_args["rank"] == self.config.quantization.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
 
                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    self.config.quantization.use_lora,
                 )
 
             from .source_transformation.pre_quantization import (
```
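Taken together, the new attribute paths in this file (`config.model.name`, `config.quantization.use_spin_quant`, `use_qat`, `use_lora`, the `preq_*` fields below, `config.misc.use_attention_sink`, and `config.sequence.max_context_length`) imply a grouped config object. A hypothetical dataclass sketch of that shape, inferred from this diff alone (field names are as used in the diff; the grouping, types, and defaults are assumptions, not the real config definition):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelCfg:
    name: str = "llama3"
    dtype_override: Optional[str] = None  # e.g. "fp32"


@dataclass
class QuantizationCfg:
    use_spin_quant: Optional[str] = None
    use_qat: bool = False
    use_lora: int = 0                      # LoRA rank; 0 means disabled
    preq_mode: Optional[str] = None        # "8da4w" or "8da4w_output_8da8w"
    preq_group_size: Optional[int] = None
    preq_embedding_quantize: Optional[str] = None  # "bitwidth,group_size"


@dataclass
class MiscCfg:
    use_attention_sink: Optional[str] = None  # "sink,window,eviction_batch"


@dataclass
class SequenceCfg:
    max_context_length: int = 128


@dataclass
class Config:
    model: ModelCfg = field(default_factory=ModelCfg)
    quantization: QuantizationCfg = field(default_factory=QuantizationCfg)
    misc: MiscCfg = field(default_factory=MiscCfg)
    sequence: SequenceCfg = field(default_factory=SequenceCfg)
```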
examples/models/llama/model.py (continued):

```diff
@@ -224,16 +224,16 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.config and self.config.misc.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink
 
-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.config.misc.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])
 
-            assert self.args.max_context_length == sink_size + window_size
+            assert self.config.sequence.max_context_length == sink_size + window_size
 
             self.model_ = enable_attention_sink(
                 module=self.model_,
```
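A small sketch of the attention-sink parameter handling in the hunk above: the config stores a single `"sink_size,window_size,eviction_batch_size"` string, which is split and checked against the maximum context length. The standalone helper and the example numbers are illustrative, not part of the diff:

```python
def parse_attention_sink_params(spec: str, max_context_length: int):
    """Parse "sink_size,window_size,eviction_batch_size" and check that
    sink + window fills the whole context length, as the diff asserts."""
    parts = spec.split(",")
    assert len(parts) == 3
    sink_size, window_size, eviction_batch_size = (int(p) for p in parts)
    assert max_context_length == sink_size + window_size
    return sink_size, window_size, eviction_batch_size


# Example: a 2048-token context split into a 4-token sink and a 2044-token window.
print(parse_attention_sink_params("4,2044,1", 2048))  # (4, 2044, 1)
```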
examples/models/llama/model.py (continued):

```diff
@@ -321,20 +321,24 @@ def get_example_inputs_kvcache_sdpa(self):
             )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.config
+        assert self.config.quantization.preq_mode, "preq_mode must be specified"
+        assert self.config.quantization.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
+        ], f"Quantization mode {self.config.quantization.preq_mode} is not compatible with SpinQuant."
+        assert (
+            self.config.quantization.preq_group_size
         ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        assert self.config.model.dtype_override, "dtype_override must be specified"
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert (
+            self.config.quantization.preq_group_size
+            == model_args.quantization_args["group_size"]
+        )
 
         mapping = {
             "fp32": torch.float32,
@@ -343,28 +347,28 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }
 
         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.config.quantization.preq_mode == "8da4w_output_8da8w":
            from .source_transformation.pre_quantization import (
                transform_output_linear_for_pre_quantization,
            )
 
            self.model_ = transform_output_linear_for_pre_quantization(
                module=self.model_,
                checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.config.model.dtype_override],
            )
 
        self.model_ = transform_linear_for_pre_quantization(
            self.model_,
            checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.config.quantization.preq_group_size,
+            mapping[self.config.model.dtype_override],
        )
 
        embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.config.quantization.preq_embedding_quantize:
            embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.config.quantization.preq_embedding_quantize.split(",")
            )
            from .source_transformation.pre_quantization import (
                transform_embedding_for_pre_quantization,
```
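For reference, a sketch of the two string-to-value conversions the pre-quantization path above relies on: the dtype-override string is resolved through a mapping like the `mapping = {"fp32": torch.float32, ...}` dict in the diff, and `preq_embedding_quantize` is a `"bitwidth,group_size"` string. Only the `"fp32"` key is visible in the hunk; the extra keys, the helper function, and the example values are assumptions:

```python
import torch

# Dtype mapping in the spirit of model.py's `mapping` dict; only "fp32" is
# visible in the diff, the other keys here are assumptions.
DTYPE_MAP = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def parse_preq_embedding_quantize(spec: str):
    # e.g. "4,32" -> 4-bit embedding weights with group size 32 (illustrative)
    bit_width, group_size = spec.split(",")
    return int(bit_width), int(group_size)


dtype = DTYPE_MAP["fp32"]                            # -> torch.float32
bits, group = parse_preq_embedding_quantize("4,32")  # -> (4, 32)
```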
examples/models/llama/model.py (continued):

```diff
@@ -382,7 +386,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         self.model_ = transform_embedding_for_pre_quantization(
             self.model_,
             checkpoint,
-            mapping[self.args.dtype_override],
+            mapping[self.config.model.dtype_override],
             int(embedding_bit_width),
             embedding_group_size,
         )
```
