
Commit 6013460

jainapurva authored and amdfaa committed
Skip tests on fbcode
Differential Revision: D67982501
Pull Request resolved: #1532
1 parent 871ab30 commit 6013460
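
The guard added in this change skips the whole test module when running inside fbcode, where the hard-coded Llama checkpoint and tokenizer files are unavailable. A minimal standalone sketch of that pattern follows; the is_fbcode stub here is a placeholder rather than torchao's actual implementation, and note that pytest.skip called outside a test function generally needs allow_module_level=True to skip cleanly instead of erroring during collection.

import pytest


def is_fbcode() -> bool:
    # Placeholder standing in for torchao.utils.is_fbcode; assume it returns
    # True only when running inside Meta's internal fbcode build environment.
    return False


# Module-level guard: evaluated at import time, before any test is collected.
if is_fbcode():
    pytest.skip(
        "Skipping the test in fbcode due to missing model and tokenizer files",
        allow_module_level=True,  # required for a skip outside a test function
    )


def test_placeholder():
    # Hypothetical test; only collected when the guard above did not fire.
    assert True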

File tree: 1 file changed, +90 -81 lines

test/quantization/test_gptq_mt.py (+90 -81)
@@ -3,11 +3,16 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from torch.testing._internal.common_utils import run_tests
 
 from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
 from torchao.quantization.utils import _lm_eval_available
+from torchao.utils import is_fbcode
+
+if is_fbcode():
+    pytest.skip("Skipping the test in fbcode due to missing model and tokenizer files")
 
 if _lm_eval_available:
     hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
@@ -247,88 +252,92 @@ def run_eval(self, tasks, limit):
         return result
 
 
-precision = torch.bfloat16
-device = "cuda"
-print("Loading model")
-checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-model = Transformer.from_name(checkpoint_path.parent.name)
-checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-model.load_state_dict(checkpoint, assign=True)
-model = model.to(dtype=precision, device="cpu")
-model.eval()
-print("Model loaded")
-tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-assert tokenizer_path.is_file(), tokenizer_path
-tokenizer = get_tokenizer(  # pyre-ignore[28]
-    tokenizer_path,
-    "Llama-2-7b-chat-hf",
-)
-print("Tokenizer loaded")
-
-
-blocksize = 128
-percdamp = 0.01
-groupsize = 64
-calibration_tasks = ["wikitext"]
-calibration_limit = None
-calibration_seq_length = 100
-input_prep_func = prepare_inputs_for_model
-pad_calibration_inputs = False
-print("Recording inputs")
-inputs = (
-    InputRecorder(
-        tokenizer,
-        calibration_seq_length,
-        input_prep_func,
-        pad_calibration_inputs,
-        model.config.vocab_size,
-        device="cpu",
+def test_gptq_mt():
+    precision = torch.bfloat16
+    device = "cuda"
+    print("Loading model")
+    checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    model.load_state_dict(checkpoint, assign=True)
+    model = model.to(dtype=precision, device="cpu")
+    model.eval()
+    print("Model loaded")
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), tokenizer_path
+    tokenizer = get_tokenizer(  # pyre-ignore[28]
+        tokenizer_path,
+        "Llama-2-7b-chat-hf",
     )
-    .record_inputs(
-        calibration_tasks,
-        calibration_limit,
+    print("Tokenizer loaded")
+
+    blocksize = 128
+    percdamp = 0.01
+    groupsize = 64
+    calibration_tasks = ["wikitext"]
+    calibration_limit = None
+    calibration_seq_length = 100
+    input_prep_func = prepare_inputs_for_model
+    pad_calibration_inputs = False
+    print("Recording inputs")
+    inputs = (
+        InputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_inputs()
     )
-    .get_inputs()
-)
-print("Inputs recorded")
-quantizer = Int4WeightOnlyGPTQQuantizer(
-    blocksize,
-    percdamp,
-    groupsize,
-)
-
-model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
-multi = [
-    MultiTensor([inp for inp, _ in inputs]),
-    MultiTensor([inds for _, inds in inputs]),
-]
-print("Quantizing model")
-model = quantizer.quantize(model, multi).cuda()
-print("Model quantized")
-print("Saving model and fixing state dict")
-regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
-for key, value in model.state_dict().items():
-    if isinstance(value, MultiTensor):
-        regular_state_dict[key] = value.values[0]
-    else:
-        regular_state_dict[key] = value
-
-model = Transformer.from_name(checkpoint_path.parent.name)
-remove = [k for k in regular_state_dict if "kv_cache" in k]
-for k in remove:
-    del regular_state_dict[k]
-
-model.load_state_dict(regular_state_dict, assign=True)
-torch.save(model.state_dict(), "model.pth")
-print("Running evaluation")
-result = TransformerEvalWrapper(
-    model.to(device),  # quantized model needs to run on cuda
-    tokenizer,
-    model.config.block_size,
-    prepare_inputs_for_model,
-).run_eval(
-    ["wikitext"],
-    None,
-)
+    print("Inputs recorded")
+    quantizer = Int4WeightOnlyGPTQQuantizer(
+        blocksize,
+        percdamp,
+        groupsize,
+    )
+
+    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
+    multi = [
+        MultiTensor([inp for inp, _ in inputs]),
+        MultiTensor([inds for _, inds in inputs]),
+    ]
+    print("Quantizing model")
+    model = quantizer.quantize(model, multi).cuda()
+    print("Model quantized")
+    print("Saving model and fixing state dict")
+    regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
+    for key, value in model.state_dict().items():
+        if isinstance(value, MultiTensor):
+            regular_state_dict[key] = value.values[0]
+        else:
+            regular_state_dict[key] = value
+
+    model = Transformer.from_name(checkpoint_path.parent.name)
+    remove = [k for k in regular_state_dict if "kv_cache" in k]
+    for k in remove:
+        del regular_state_dict[k]
+
+    model.load_state_dict(regular_state_dict, assign=True)
+    torch.save(model.state_dict(), "model.pth")
+    print("Running evaluation")
+    TransformerEvalWrapper(
+        model.to(device),  # quantized model needs to run on cuda
+        tokenizer,
+        model.config.block_size,
+        prepare_inputs_for_model,
+    ).run_eval(
+        ["wikitext"],
+        None,
+    )
+
+
+if __name__ == "__main__":
+    run_tests()
 
 # wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
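
With the script body now wrapped in test_gptq_mt() and the run_tests() entry point added, the file can presumably be collected by pytest or executed directly as a script; either way it still expects the Llama-2-7b-chat-hf checkpoint and tokenizer at the hard-coded checkpoints/ path, which is precisely what is unavailable in fbcode and why the module-level skip was introduced.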
