@@ -53,7 +53,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.config = kwargs.get("config", None)
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -156,10 +156,10 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            assert self.config.model.name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if self.config and self.config.model.name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32
 
         if kwargs.get("verbose", False):
@@ -194,7 +194,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.config and self.config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -203,19 +203,19 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.config and self.config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.config and self.config.quantization.use_lora:
+                assert model_args.lora_args["rank"] == self.config.quantization.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
 
                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    self.config.quantization.use_lora,
                 )
 
             from .source_transformation.pre_quantization import (
@@ -224,16 +224,16 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.config and self.config.misc.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink
 
-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.config.misc.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])
 
-            assert self.args.max_context_length == sink_size + window_size
+            assert self.config.sequence.max_context_length == sink_size + window_size
 
             self.model_ = enable_attention_sink(
                 module=self.model_,
@@ -321,20 +321,24 @@ def get_example_inputs_kvcache_sdpa(self):
         )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.config
+        assert self.config.quantization.preq_mode, "preq_mode must be specified"
+        assert self.config.quantization.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
+        ], f"Quantization mode {self.config.quantization.preq_mode} is not compatible with SpinQuant."
+        assert (
+            self.config.quantization.preq_group_size
         ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        assert self.config.model.dtype_override, "dtype_override must be specified"
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert (
+            self.config.quantization.preq_group_size
+            == model_args.quantization_args["group_size"]
+        )
 
         mapping = {
             "fp32": torch.float32,
@@ -343,28 +347,28 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }
 
         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.config.quantization.preq_mode == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )
 
             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.config.model.dtype_override],
             )
 
         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.config.quantization.preq_group_size,
+            mapping[self.config.model.dtype_override],
         )
 
         embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.config.quantization.preq_embedding_quantize:
             embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.config.quantization.preq_embedding_quantize.split(",")
             )
             from .source_transformation.pre_quantization import (
                 transform_embedding_for_pre_quantization,
@@ -382,7 +386,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.args.dtype_override],
+                mapping[self.config.model.dtype_override],
                 int(embedding_bit_width),
                 embedding_group_size,
             )
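
Taken together, the diff reads the new self.config through nested groups: config.model (name, dtype_override), config.quantization (use_spin_quant, use_qat, use_lora, preq_mode, preq_group_size, preq_embedding_quantize), config.misc (use_attention_sink), and config.sequence (max_context_length). The dataclass sketch below only illustrates the shape those accesses assume; the class names, defaults, and anything not directly read in the diff are hypothetical, and the real config definition lives elsewhere in the PR.

# Hypothetical sketch of the nested config shape implied by the accesses above.
# Field names mirror the attributes read in model.py; class names and defaults
# are illustrative assumptions, not the actual ExecuTorch definitions.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelCfg:
    name: str = "llama3"          # e.g. "llama2", "stories110m", "llama3_1"
    dtype_override: str = "fp32"  # key into the fp32/fp16/bf16 mapping


@dataclass
class QuantizationCfg:
    use_spin_quant: bool = False
    use_qat: bool = False
    use_lora: int = 0                              # LoRA rank; 0 means disabled
    preq_mode: Optional[str] = None                # "8da4w" or "8da4w_output_8da8w"
    preq_group_size: Optional[int] = None
    preq_embedding_quantize: Optional[str] = None  # "<bit_width>,<group_size>"


@dataclass
class MiscCfg:
    use_attention_sink: Optional[str] = None  # "<sink_size>,<window_size>,<eviction_batch_size>"


@dataclass
class SequenceCfg:
    max_context_length: int = 128


@dataclass
class LlmConfigSketch:
    model: ModelCfg = field(default_factory=ModelCfg)
    quantization: QuantizationCfg = field(default_factory=QuantizationCfg)
    misc: MiscCfg = field(default_factory=MiscCfg)
    sequence: SequenceCfg = field(default_factory=SequenceCfg)

A caller would then construct such an object and pass it in the same way the old code passed args, i.e. as the "config" keyword argument that kwargs.get("config", None) picks up in __init__.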