 import pytest
 import torch
 import torch.nn.functional as F
+from torch.testing._internal.common_utils import run_tests

 from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+from torchao.utils import is_fbcode
+
+if is_fbcode():
+    pytest.skip("Skipping the test in fbcode due to missing model and tokenizer files", allow_module_level=True)

 if _lm_eval_available:
     hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")

@@ -247,88 +252,92 @@ def run_eval(self, tasks, limit):
         return result


-precision = torch.bfloat16
-device = "cuda"
-print("Loading model")
-checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-model = Transformer.from_name(checkpoint_path.parent.name)
-checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-model.load_state_dict(checkpoint, assign=True)
-model = model.to(dtype=precision, device="cpu")
-model.eval()
-print("Model loaded")
-tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-assert tokenizer_path.is_file(), tokenizer_path
-tokenizer = get_tokenizer(  # pyre-ignore[28]
-    tokenizer_path,
-    "Llama-2-7b-chat-hf",
-)
-print("Tokenizer loaded")
-
-
-blocksize = 128
-percdamp = 0.01
-groupsize = 64
-calibration_tasks = ["wikitext"]
-calibration_limit = None
-calibration_seq_length = 100
-input_prep_func = prepare_inputs_for_model
-pad_calibration_inputs = False
-print("Recording inputs")
-inputs = (
-    InputRecorder(
-        tokenizer,
-        calibration_seq_length,
-        input_prep_func,
-        pad_calibration_inputs,
-        model.config.vocab_size,
-        device="cpu",
+def test_gptq_mt():
+    precision = torch.bfloat16
+    device = "cuda"
+    print("Loading model")
+    checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    model.load_state_dict(checkpoint, assign=True)
+    model = model.to(dtype=precision, device="cpu")
+    model.eval()
+    print("Model loaded")
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), tokenizer_path
+    tokenizer = get_tokenizer(  # pyre-ignore[28]
+        tokenizer_path,
+        "Llama-2-7b-chat-hf",
     )
-    .record_inputs(
-        calibration_tasks,
-        calibration_limit,
+    print("Tokenizer loaded")
+
+    blocksize = 128
+    percdamp = 0.01
+    groupsize = 64
+    calibration_tasks = ["wikitext"]
+    calibration_limit = None
+    calibration_seq_length = 100
+    input_prep_func = prepare_inputs_for_model
+    pad_calibration_inputs = False
+    print("Recording inputs")
+    inputs = (
+        InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_inputs()
     )
-    .get_inputs()
-)
-print("Inputs recorded")
-quantizer = Int4WeightOnlyGPTQQuantizer(
-    blocksize,
-    percdamp,
-    groupsize,
-)
-
-model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [
-    MultiTensor([inp for inp, _ in inputs]),
-    MultiTensor([inds for _, inds in inputs]),
-]
-print("Quantizing model")
-model = quantizer.quantize(model, multi).cuda()
-print("Model quantized")
-print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
-for key, value in model.state_dict().items():
-    if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
-    else:
-        regular_state_dict[key] = value
-
-model = Transformer.from_name(checkpoint_path.parent.name)
-remove = [k for k in regular_state_dict if "kv_cache" in k]
-for k in remove:
-    del regular_state_dict[k]
-
-model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), "model.pth")
-print("Running evaluation")
-result = TransformerEvalWrapper(
-    model.to(device),  # quantized model needs to run on cuda
-    tokenizer,
-    model.config.block_size,
-    prepare_inputs_for_model,
-).run_eval(
-    ["wikitext"],
-    None,
-)
+    print("Inputs recorded")
+    quantizer = Int4WeightOnlyGPTQQuantizer(
+        blocksize,
+        percdamp,
+        groupsize,
+    )
+
+    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+    multi = [
+        MultiTensor([inp for inp, _ in inputs]),
+        MultiTensor([inds for _, inds in inputs]),
+    ]
+    print("Quantizing model")
+    model = quantizer.quantize(model, multi).cuda()
+    print("Model quantized")
+    print("Saving model and fixing state dict")
+    regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
+    for key, value in model.state_dict().items():
+        if isinstance(value, MultiTensor):
+            regular_state_dict[key] = value.values[0]
+        else:
+            regular_state_dict[key] = value
+
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    remove = [k for k in regular_state_dict if "kv_cache" in k]
+    for k in remove:
+        del regular_state_dict[k]
+
+    model.load_state_dict(regular_state_dict, assign=True)
+    torch.save(model.state_dict(), "model.pth")
+    print("Running evaluation")
+    TransformerEvalWrapper(
+        model.to(device),  # quantized model needs to run on cuda
+        tokenizer,
+        model.config.block_size,
+        prepare_inputs_for_model,
+    ).run_eval(
+        ["wikitext"],
+        None,
+    )
+
+
+if __name__ == "__main__":
+    run_tests()

 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
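The test above relies on the MultiTensor bundling/unwrapping pattern from torchao.quantization.GPTQ_MT. The short sketch below illustrates only that pattern; the sample count, shapes, and vocab size are illustrative assumptions and not values taken from the diff.

# Sketch of the MultiTensor pattern exercised by test_gptq_mt (illustrative data).
import torch

from torchao.quantization.GPTQ_MT import MultiTensor

# Two fake calibration samples of (token ids, position indices); shapes are assumptions.
samples = [
    (torch.randint(0, 32000, (1, 100)), torch.arange(100).view(1, -1)),
    (torch.randint(0, 32000, (1, 100)), torch.arange(100).view(1, -1)),
]

# Bundle the per-sample tensors so one quantizer.quantize(model, multi) call
# sees every calibration sample, as the test does.
multi = [
    MultiTensor([inp for inp, _ in samples]),
    MultiTensor([inds for _, inds in samples]),
]

# MultiTensor-valued entries are collapsed back to plain tensors by taking the
# first constituent, mirroring the state-dict fix-up in the test.
plain = multi[0].values[0]
assert isinstance(plain, torch.Tensor)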