[Not for landing] quant recipe for phi4-mini #2038

Draft · wants to merge 2 commits into base: main
126 changes: 126 additions & 0 deletions quant_recipe.py
@@ -0,0 +1,126 @@
# Quantization
import lm_eval
import torch
from lm_eval import evaluator
from lm_eval.utils import (
make_table,
)
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
)

model_id = "microsoft/Phi-4-mini-instruct"
ENABLE_GPTQ = False
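# With ENABLE_GPTQ=False the script applies plain post-training quantization via quantize_();
# the GPTQ branch below is still marked as unfinished.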
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
MappingType,
ZeroPointDomain,
quantize_,
)

model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id, padding_side="right")

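# Embedding weights: int4 weight-only quantization, per-group (group size 32),
# asymmetric mapping with integer zero points.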
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,  # alternatively torch.int8
    granularity=PerGroup(32),  # alternatively PerAxis(0)
    mapping_type=MappingType.ASYMMETRIC,
    zero_point_domain=ZeroPointDomain.INT,
    # scale_dtype=torch.float32,
)
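# Linear layers: dynamically quantized int8 activations with int4 weights,
# per-group (group size 128), symmetric mapping (no zero point).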
linear_config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=PerGroup(128),
weight_mapping_type=MappingType.SYMMETRIC,
weight_zero_point_domain=ZeroPointDomain.NONE,
# weight_scale_dtype=torch.bfloat16,
)

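# Quantize only the nn.Embedding modules; the filter lambda selects them by module type.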
quantize_(model, embedding_config, lambda m, fqn: isinstance(m, torch.nn.Embedding))

if not ENABLE_GPTQ:
quantize_(
model,
linear_config,
)
else:
# This needs work
assert False, "GPTQ is not set up yet"
calibration_tasks = ["hellaswag"]
calibration_seq_length = 80
calibration_limit = 1000
pad_calibration_inputs = True
from torchao._models._eval import QwenMultiTensorInputRecorder
from torchao._models.llama.model import prepare_inputs_for_model
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer

# Assert Int4WeightOnlyConfig compatibility
# TODO: check if this is identical to Int4WeightOnlyGPTQQuantizer (other than dynamic activation bit)
assert linear_config.weight_dtype == torch.int4
assert isinstance(linear_config.weight_granularity, PerGroup)
    groupsize = linear_config.weight_granularity.group_size
assert linear_config.weight_zero_point_domain == ZeroPointDomain.NONE
assert linear_config.weight_mapping_type == MappingType.ASYMMETRIC

device = "cuda"
assert groupsize in [
32,
64,
128,
256,
], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
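    # Record calibration inputs from the hellaswag task, then run GPTQ int4 weight-only
    # quantization over them.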
inputs = (
QwenMultiTensorInputRecorder(
tokenizer,
calibration_seq_length,
prepare_inputs_for_model,
pad_calibration_inputs,
model.config.vocab_size,
device="cpu",
)
.record_inputs(
calibration_tasks,
calibration_limit,
)
.get_inputs()
)
quantizer = Int4WeightOnlyGPTQQuantizer(group_size=groupsize, device=device)
model = quantizer.quantize(model, kw_inputs=inputs).to(device)

# prompt = "Hey, are you conscious? Can you talk to me?"
# inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

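# Quick generation smoke test with a chat-formatted prompt.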
messages = [
{
"role": "system",
"content": "You are a medieval knight and must provide explanations to modern people.",
},
{"role": "user", "content": "How should I explain the Internet?"},
]
inputs = tokenizer(
tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
),
return_tensors="pt",
).to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text[0])

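# Evaluate the quantized model on hellaswag with the lm-eval harness and print the results table.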
# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md
lm_eval_model = lm_eval.models.huggingface.HFLM(pretrained=model, batch_size=64)
results = evaluator.simple_evaluate(
lm_eval_model, tasks=["hellaswag"], device="cuda", batch_size="auto"
)
print(make_table(results))
190 changes: 185 additions & 5 deletions torchao/dtypes/uintx/q_dq_layout.py
@@ -24,19 +24,199 @@
logger.addHandler(handler)


from torchao.dtypes.utils import PlainLayout
from dataclasses import dataclass
from typing import Optional, Tuple

from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)

from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.utils import fill_defaults

aten = torch.ops.aten

class QDQLayout(PlainLayout):

@dataclass(frozen=True)
class QDQLayout(Layout):
pass


from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
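# Two QDQTensorImpl instances are treated as having the same metadata only when their shapes,
# zero-point presence, and layout type all match; used to validate in-place copy_.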
def _same_metadata(self: "QDQTensorImpl", src: "QDQTensorImpl") -> bool:
    return (
        isinstance(self, QDQTensorImpl)
        and isinstance(src, QDQTensorImpl)
        and self.shape == src.shape
        and self.int_data.shape == src.int_data.shape
        and self.scale.shape == src.scale.shape
        and (
            (self.zero_point is None and src.zero_point is None)
            or (
                self.zero_point is not None
                and src.zero_point is not None
                and self.zero_point.shape == src.zero_point.shape
            )
        )
        and type(self._layout) == type(src._layout)
    )


@register_layout(QDQLayout)
class _Impl(PlainAQTTensorImpl):
pass
class QDQTensorImpl(AQTTensorImpl):
"""
TensorImpl for QDQLayout layout for affine quantized tensor, it stores int_data, scale, zero_point
tensors directly as plain tensors.

fields:
int_data (torch.Tensor): the quantized integer data Tensor
scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor
zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
"""

def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
kwargs = {}
kwargs["device"] = int_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
kwargs["dtype"] = int_data.dtype
kwargs["requires_grad"] = False
shape = int_data.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point
self._layout = _layout

def __tensor_flatten__(self):
if self.zero_point is None:
return ["int_data", "scale"], [self._layout]
return ["int_data", "scale", "zero_point"], [self._layout]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = (
tensor_data_dict["int_data"],
tensor_data_dict["scale"],
tensor_data_dict.get("zero_point", None),
)
(_layout,) = tensor_attributes
return cls(int_data, scale, zero_point, _layout)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"])
if self.zero_point is not None
else None,
self._layout,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point) if self.zero_point is not None else None,
self._layout,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

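        # Supported ops: detach, clone, copy_ (with matching metadata), transpose, and slice
        # along dim 0 or 1; anything else raises NotImplementedError.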
if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

elif func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

elif func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
)

elif func is aten.t.default:
tensor = args[0]
new = tensor.__class__(
tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout
)
return return_and_correct_aliasing(func, args, kwargs, new)

elif func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0]._apply_fn_to_data(
lambda x: aten.slice.Tensor(x, dim, start, end, step)
),
)
elif dim == 1:
assert (
len(self.scale.shape) == 1
), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
return QDQTensorImpl(
aten.slice.Tensor(self.int_data, dim, start, end, step),
self.scale.view(-1),
self.zero_point.view(-1) if self.zero_point is not None else None,
self._layout,
)
else:
raise NotImplementedError(
f"QDQTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)

raise NotImplementedError(
f"QDQTensorImpl dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.int_data, self.scale, self.zero_point

def get_layout(self) -> Layout:
return self._layout

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
assert isinstance(_layout, QDQLayout)
return cls(int_data, scale, zero_point, _layout)


def _linear_check(input_tensor, weight_tensor, bias):