from typing import Tuple

import torch


class ToyModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class QuantizedLinear(torch.nn.Linear):
    """
    Linear module that performs dynamic, symmetric, weight-only int8
    quantization: the fp32 weight is requantized on every forward call.
    Assumes bias=False, since forward() never adds a bias term.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_int8, scale = int8_symmetric_quantize(self.weight)
        # Matmul in the input dtype, then rescale each output channel.
        return torch.mm(x, w_int8.t().to(x.dtype)) * scale.t()

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        # nn.Linear's third argument is the bool `bias` flag, not the bias
        # tensor itself, so convert before passing it through.
        new_linear = cls(mod.in_features, mod.out_features, mod.bias is not None)
        new_linear.weight = mod.weight
        return new_linear


def int8_symmetric_quantize(
    fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Symmetrically quantize a torch.float32 tensor into torch.int8, with
    one scale per row (i.e. per output channel of a linear weight).
    Return a 2-tuple of (quantized values, scales).
    """
    quant_min = -128
    quant_max = 127
    # Per-row min and max of the tensor.
    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
    # Symmetric quantization centers the range around zero, so the scale
    # is driven by the largest magnitude in each row.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    # scale = max|row| / 127.5, the same convention PyTorch's symmetric
    # observers use.
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    scale = scale.view(fp32_tensor.shape[0], -1)
    out = torch.round(fp32_tensor * (1.0 / scale))
    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
    return out, scale
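

# A minimal sanity-check sketch (a hypothetical helper, not part of the
# original script): quantization rounds each element to the nearest
# multiple of its row's scale, so dequantizing with `w_int8 * scale`
# should reproduce the fp32 weight to within about half a scale step.
def _max_round_trip_error(w: torch.Tensor) -> torch.Tensor:
    w_int8, scale = int8_symmetric_quantize(w)
    # Dequantize (the per-row scales broadcast over columns) and compare.
    return (w_int8.to(w.dtype) * scale - w).abs().max()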


if __name__ == "__main__":
    model = ToyModel(64, 128, 32).cuda()
    example_inputs = torch.randn((1, 64), dtype=torch.float32, device="cuda")

    # Swap torch.nn.Linear with QuantizedLinear. Check the exact type rather
    # than isinstance, so already-swapped QuantizedLinear modules (which
    # subclass nn.Linear) are not wrapped twice.
    for name, child in model.named_children():
        if type(child) is torch.nn.Linear:
            new_linear = QuantizedLinear.from_float(child)
            setattr(model, name, new_linear)

    print("quantized model: ", model)
    print("output: ", model(example_inputs))
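
    # A hedged verification sketch (an addition, assuming the fp32 weights
    # shared via from_float above): the quantized output should closely
    # track the float matmul path computed from the same weights.
    with torch.no_grad():
        float_ref = example_inputs @ model.linear1.weight.t() @ model.linear2.weight.t()
        quant_out = model(example_inputs)
        print("max abs error vs. float: ", (quant_out - float_ref).abs().max().item())
        print("max weight round-trip error: ", _max_round_trip_error(model.linear1.weight).item())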