diff --git a/quant_recipe.py b/quant_recipe.py
new file mode 100644
index 0000000000..e7cd83bd7d
--- /dev/null
+++ b/quant_recipe.py
@@ -0,0 +1,126 @@
+# Quantization recipe for microsoft/Phi-4-mini-instruct: int4 weight-only
+# embeddings plus int8-dynamic-activation/int4-weight linears, with an
+# optional (unfinished) GPTQ path.
+import lm_eval
+import torch
+from lm_eval import evaluator
+from lm_eval.utils import (
+    make_table,
+)
+from transformers import (
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+)
+
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    MappingType,
+    ZeroPointDomain,
+    quantize_,
+)
+
+model_id = "microsoft/Phi-4-mini-instruct"
+ENABLE_GPTQ = False
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, torch_dtype="auto", device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id, padding_side="right")
+
+# Weight-only config for embedding layers.
+embedding_config = IntxWeightOnlyConfig(
+    weight_dtype=torch.int4,  # or torch.int8
+    granularity=PerGroup(32),  # or PerAxis(0)
+    mapping_type=MappingType.ASYMMETRIC,
+    zero_point_domain=ZeroPointDomain.INT,
+    # scale_dtype=torch.float32,
+)
+# Dynamic-activation config for linear layers.
+linear_config = Int8DynamicActivationIntxWeightConfig(
+    weight_dtype=torch.int4,
+    weight_granularity=PerGroup(128),
+    weight_mapping_type=MappingType.SYMMETRIC,
+    weight_zero_point_domain=ZeroPointDomain.NONE,
+    # weight_scale_dtype=torch.bfloat16,
+)
+
+quantize_(model, embedding_config, lambda m, fqn: isinstance(m, torch.nn.Embedding))
+
+if not ENABLE_GPTQ:
+    quantize_(
+        model,
+        linear_config,
+    )
+else:
+    # This needs work
+    raise NotImplementedError("GPTQ is not set up yet")
+    calibration_tasks = ["hellaswag"]
+    calibration_seq_length = 80
+    calibration_limit = 1000
+    pad_calibration_inputs = True
+    from torchao._models._eval import QwenMultiTensorInputRecorder
+    from torchao._models.llama.model import prepare_inputs_for_model
+    from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer
+
+    # Assert that linear_config is compatible with Int4WeightOnlyGPTQQuantizer.
+    # TODO: check if this is identical to Int4WeightOnlyGPTQQuantizer
+    # (other than the dynamic activation bit)
+    assert linear_config.weight_dtype == torch.int4
+    assert isinstance(linear_config.weight_granularity, PerGroup)
+    groupsize = linear_config.weight_granularity.group_size
+    assert linear_config.weight_zero_point_domain == ZeroPointDomain.NONE
+    # NOTE: linear_config above uses SYMMETRIC, so this assert would fail as
+    # written; part of why this path is not set up yet.
+    assert linear_config.weight_mapping_type == MappingType.ASYMMETRIC
+
+    device = "cuda"
+    assert groupsize in [
+        32,
+        64,
+        128,
+        256,
+    ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {groupsize}"
+    inputs = (
+        QwenMultiTensorInputRecorder(
+            tokenizer,
+            calibration_seq_length,
+            prepare_inputs_for_model,
+            pad_calibration_inputs,
+            model.config.vocab_size,
+            device="cpu",
+        )
+        .record_inputs(
+            calibration_tasks,
+            calibration_limit,
+        )
+        .get_inputs()
+    )
+    quantizer = Int4WeightOnlyGPTQQuantizer(group_size=groupsize, device=device)
+    model = quantizer.quantize(model, kw_inputs=inputs).to(device)
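+
+# Optional sanity check (a sketch, not part of the original recipe): after
+# quantize_, the affected weights are tensor subclasses rather than plain
+# torch.Tensor, so printing the weight type of the first embedding and the
+# first linear module confirms both configs were applied.
+_seen = set()
+for _name, _module in model.named_modules():
+    if isinstance(_module, (torch.nn.Embedding, torch.nn.Linear)):
+        _kind = type(_module).__name__
+        if _kind not in _seen:
+            _seen.add(_kind)
+            print(f"{_name} ({_kind}): {type(_module.weight).__name__}")
+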
+# prompt = "Hey, are you conscious? Can you talk to me?"
+# inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+
+messages = [
+    {
+        "role": "system",
+        "content": "You are a medieval knight and must provide explanations to modern people.",
+    },
+    {"role": "user", "content": "How should I explain the Internet?"},
+]
+inputs = tokenizer(
+    tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    ),
+    return_tensors="pt",
+).to("cuda")
+generated_ids = model.generate(**inputs, max_new_tokens=128)
+output_text = tokenizer.batch_decode(
+    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+)
+print(output_text[0])
+
+# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md
+lm_eval_model = lm_eval.models.huggingface.HFLM(pretrained=model, batch_size=64)
+results = evaluator.simple_evaluate(
+    lm_eval_model, tasks=["hellaswag"], device="cuda", batch_size="auto"
+)
+print(make_table(results))
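+
+# Saving the quantized checkpoint is omitted above. Note (an assumption based
+# on the HF torchao integration docs, not part of this recipe) that tensor
+# subclass weights are not safetensors-serializable, so saving would look like:
+# model.save_pretrained("phi4-mini-int4-quantized", safe_serialization=False)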
diff --git a/torchao/dtypes/uintx/q_dq_layout.py b/torchao/dtypes/uintx/q_dq_layout.py
index 1d5b2048b0..530ab8a535 100644
--- a/torchao/dtypes/uintx/q_dq_layout.py
+++ b/torchao/dtypes/uintx/q_dq_layout.py
@@ -24,19 +24,199 @@
 logger.addHandler(handler)
 
-from torchao.dtypes.utils import PlainLayout
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
+from torchao.dtypes.utils import AQTTensorImpl, Layout
+from torchao.utils import fill_defaults
+
+aten = torch.ops.aten
 
-class QDQLayout(PlainLayout):
+
+@dataclass(frozen=True)
+class QDQLayout(Layout):
     pass
 
-from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
+
+def _same_metadata(self: "QDQTensorImpl", src: "QDQTensorImpl") -> bool:
+    # All shapes, the zero_point presence (and shape, when present), and the
+    # layout type must match.
+    return (
+        isinstance(self, QDQTensorImpl)
+        and isinstance(src, QDQTensorImpl)
+        and self.shape == src.shape
+        and self.int_data.shape == src.int_data.shape
+        and self.scale.shape == src.scale.shape
+        and (
+            (self.zero_point is None and src.zero_point is None)
+            or (
+                self.zero_point is not None
+                and src.zero_point is not None
+                and self.zero_point.shape == src.zero_point.shape
+            )
+        )
+        and type(self._layout) == type(src._layout)
+    )
 
 
 @register_layout(QDQLayout)
-class _Impl(PlainAQTTensorImpl):
-    pass
+class QDQTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for the QDQLayout layout of an affine quantized tensor; it stores
+    the int_data, scale, and zero_point tensors directly as plain tensors.
+
+    fields:
+      int_data (torch.Tensor): the quantized integer data Tensor
+      scale (torch.Tensor): the scale Tensor used to map between the floating point tensor and the quantized tensor
+      zero_point (Optional[torch.Tensor]): the optional zero_point Tensor used to map between the floating point tensor and the quantized tensor
+    """
+
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        kwargs = {
+            "device": int_data.device,
+            "layout": int_data.layout,
+            "dtype": int_data.dtype,
+            "requires_grad": False,
+        }
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        if self.zero_point is None:
+            return ["int_data", "scale"], [self._layout]
+        return ["int_data", "scale", "zero_point"], [self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = (
+            tensor_data_dict["int_data"],
+            tensor_data_dict["scale"],
+            tensor_data_dict.get("zero_point", None),
+        )
+        (_layout,) = tensor_attributes
+        return cls(int_data, scale, zero_point, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point) if self.zero_point is not None else None,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        elif func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
+        elif func is aten.t.default:
+            tensor = args[0]
+            new = tensor.__class__(
+                tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout
+            )
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func,
+                    args,
+                    kwargs,
+                    args[0]._apply_fn_to_data(
+                        lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                    ),
+                )
+            elif dim == 1:
+                assert (
+                    len(self.scale.shape) == 1
+                ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return QDQTensorImpl(
+                    aten.slice.Tensor(self.int_data, dim, start, end, step),
+                    self.scale.view(-1),
+                    self.zero_point.view(-1) if self.zero_point is not None else None,
+                    self._layout,
+                )
+            else:
+                raise NotImplementedError(
+                    f"QDQTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"QDQTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        return self.int_data, self.scale, self.zero_point
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, QDQLayout)
+        return cls(int_data, scale, zero_point, _layout)
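+
+# Minimal round-trip sketch (illustrative, not part of this change): because
+# QDQTensorImpl implements __tensor_flatten__/__tensor_unflatten__, an impl can
+# be rebuilt from its constituent plain tensors, which is what torch.compile
+# and tensor-subclass serialization paths rely on:
+#
+#   impl = QDQTensorImpl.from_plain(int_data, scale, zero_point, QDQLayout())
+#   names, ctx = impl.__tensor_flatten__()
+#   rebuilt = QDQTensorImpl.__tensor_unflatten__(
+#       {name: getattr(impl, name) for name in names}, ctx, impl.shape, None
+#   )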
 
 
 def _linear_check(input_tensor, weight_tensor, bias):