36 | 36 | from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa |
37 | 37 | from torchao.quantization.quant_api import ( |
38 | 38 | int4_weight_only, |
| 39 | + int8_weight_only, |
39 | 40 | Int4WeightOnlyQuantizer, |
40 | 41 | Int8DynActInt4WeightQuantizer, |
41 | 42 | quantize_, |
@@ -110,12 +111,20 @@ def quantize_model( |
110 | 111 | if quantizer not in quantizer_class_dict: |
111 | 112 | raise RuntimeError(f"unknown quantizer {quantizer} specified") |
112 | 113 | else: |
| 114 | + ao_quant = True |
113 | 115 | # Use tensor subclass API for int4 weight only. |
114 | 116 | if device == "cuda" and quantizer == "linear:int4": |
115 | 117 | quantize_(model, int4_weight_only(q_kwargs["groupsize"])) |
| 118 | + elif quantizer == "linear:int8": |
| 119 | + print("quantizer is linear int8") |
| 120 | + quantize_(model, int8_weight_only()) |
| 121 | + else: |
| 122 | + ao_quant = False |
| 123 | + if ao_quant: |
116 | 124 | if not support_tensor_subclass: |
117 | 125 | unwrap_tensor_subclass(model) |
118 | 126 | continue |
| 127 | + |
119 | 128 |
120 | 129 | if quantizer in ["linear:a8wxdq", "embedding:wx"]: |
121 | 130 | # These quantizers require float32 input weights. Note that after quantization, |
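
For reference, a minimal sketch of the torchao tensor-subclass path that the new linear:int8 branch takes (the toy model and the torchao.utils import location are assumptions for illustration, not part of this change):

import torch.nn as nn
from torchao.quantization.quant_api import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass  # import location may vary across torchao versions

model = nn.Sequential(nn.Linear(64, 64))
# Swaps each nn.Linear weight for an int8 weight-only quantized tensor subclass, in place.
quantize_(model, int8_weight_only())
# Only needed when the downstream export/runtime cannot trace tensor subclasses,
# mirroring the support_tensor_subclass check above.
unwrap_tensor_subclass(model)
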
@@ -529,147 +538,6 @@ def linear_int8_et(input, weight, scales): |
529 | 538 | ) |
530 | 539 |
531 | 540 |
532 | | -class WeightOnlyInt8Linear(nn.Module): |
533 | | - __constants__ = ["in_features", "out_features"] |
534 | | - in_features: int |
535 | | - out_features: int |
536 | | - weight: torch.Tensor |
537 | | - scales: torch.Tensor |
538 | | - |
539 | | - def __init__( |
540 | | - self, |
541 | | - in_features, |
542 | | - out_features, |
543 | | - bias=None, |
544 | | - device=None, |
545 | | - dtype=None, |
546 | | - *, |
547 | | - weight: Optional[torch.Tensor] = None, |
548 | | - scales: Optional[torch.Tensor] = None, |
549 | | - groupsize: Optional[int] = None, |
550 | | - ): |
551 | | - super().__init__() |
552 | | - if dtype is None: |
553 | | - dtype = torch.get_default_dtype() |
554 | | - |
555 | | - if device is None: |
556 | | - device = "cpu" |
557 | | - |
558 | | - assert not bias, "Bias is not supported by LinearInt8" |
559 | | - self.in_features = in_features |
560 | | - self.out_features = out_features |
561 | | - |
562 | | - assert (weight is None) == bool( |
563 | | - scales is None |
564 | | - ), "must specify both weights and scales, or neither" |
565 | | - if weight is None: |
566 | | - weight = torch.empty( |
567 | | - (out_features, in_features), |
568 | | - dtype=torch.int8, |
569 | | - device=device, |
570 | | - ) |
571 | | - if groupsize is None or (groupsize == 0): |
572 | | - scales = torch.empty(out_features, dtype=dtype, device=device) |
573 | | - else: |
574 | | - n_groups = (in_features + groupsize - 1) // groupsize |
575 | | - scales = torch.empty(out_features, n_groups, dtype=dtype, device=device) |
576 | | - |
577 | | - self.register_buffer("weight", weight.to(device)) |
578 | | - self.register_buffer("scales", scales.to(device)) |
579 | | - |
580 | | - if use_et_backend(): |
581 | | - self.forward = self.et_forward |
582 | | - else: |
583 | | - self.forward = self.aoti_forward |
584 | | - |
585 | | - def aoti_forward(self, input: torch.Tensor) -> torch.Tensor: |
586 | | - return linear_int8_aoti(input, self.weight, self.scales) |
587 | | - |
588 | | - def et_forward(self, input: torch.Tensor) -> torch.Tensor: |
589 | | - return linear_int8_et(input, self.weight, self.scales) |
590 | | - |
591 | | - |
592 | | -class WeightOnlyInt8QuantHandler(QuantHandler): |
593 | | - def __init__( |
594 | | - self, |
595 | | - model: Optional[nn.Module] = None, |
596 | | - device = None, |
597 | | - precision=None, |
598 | | - tokenizer=None, |
599 | | - *, |
600 | | - node_type: str = "*", |
601 | | - bitwidth: Optional[int] = None, |
602 | | - groupsize: Optional[int] = None, |
603 | | - ): |
604 | | - self.model_ = model |
605 | | - self.device = device |
606 | | - self.groupsize = groupsize |
607 | | - self.node_type = node_type |
608 | | - if bitwidth is None: |
609 | | - self.bitwidth = 8 |
610 | | - else: |
611 | | - self.bitwidth = bitwidth |
612 | | - |
613 | | - @torch.no_grad() |
614 | | - def quantize(self, module): |
615 | | - # cur_state_dict = state_dict_device(self.model_.state_dict()) |
616 | | - # dict_device = "cpu" # self.device |
617 | | - |
618 | | - if self.bitwidth == 4: |
619 | | - range_min = -8 |
620 | | - range_max = 7 |
621 | | - elif self.bitwidth == 8: |
622 | | - range_min = -128 |
623 | | - range_max = 127 |
624 | | - else: |
625 | | - raise ValueError(f"Unsupported bitwidth {self.bitwidth}") |
626 | | - |
627 | | - for name, child in module.named_children(): |
628 | | - # print(f"name: {name}") |
629 | | - if isinstance(child, nn.Linear): |
630 | | - if ( |
631 | | - (self.node_type == "*") |
632 | | - or (self.node_type == "output" and name == "output") |
633 | | - or (self.node_type == "!output" and name != "output") |
634 | | - ): |
635 | | - # print(f"{name, child}") |
636 | | - input_weight = child.weight.float() |
637 | | - # print(f"{name, child}") |
638 | | - # print(f"in_features: {child.in_features}") |
639 | | - # print(f"out_features: {child.out_features}") |
640 | | - |
641 | | - # print(f"expanded weight shape {input_weight.shape}") |
642 | | - weight, scales, _ = dynamically_quantize_per_channel( |
643 | | - input_weight, |
644 | | - range_min, |
645 | | - range_max, |
646 | | - torch.int8, |
647 | | - self.groupsize, |
648 | | - scales_dtype=child.weight.dtype, |
649 | | - ) |
650 | | - |
651 | | - setattr( |
652 | | - module, |
653 | | - name, |
654 | | - WeightOnlyInt8Linear( |
655 | | - in_features=child.in_features, |
656 | | - out_features=child.out_features, |
657 | | - device=self.device, |
658 | | - # update variables from quantization |
659 | | - weight=weight, |
660 | | - scales=scales, |
661 | | - groupsize=self.groupsize, |
662 | | - ), |
663 | | - ) |
664 | | - else: |
665 | | - self.quantize(child) |
666 | | - |
667 | | - return module |
668 | | - |
669 | | - def quantized_model(self) -> nn.Module: |
670 | | - return self.quantize(self.model_) |
671 | | - |
672 | | - |
673 | 541 | ######################################################################### |
674 | 542 | ##### embedding table quantization ###### |
675 | 543 | ### (unify with torchao in future) ### |
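
For reference, the WeightOnlyInt8Linear / WeightOnlyInt8QuantHandler pair removed above implemented symmetric per-channel int8 weight-only quantization of nn.Linear modules in-tree. A rough standalone sketch of that scheme, now covered by torchao (the helper below is illustrative, not the deleted dynamically_quantize_per_channel implementation):

import torch
import torch.nn.functional as F

def per_channel_int8_sketch(weight: torch.Tensor):
    # weight: (out_features, in_features) float tensor from an nn.Linear
    max_abs = weight.abs().amax(dim=1, keepdim=True)   # one scale per output channel
    scales = (max_abs / 127.0).clamp(min=1e-8)         # guard against all-zero rows
    q = torch.clamp(torch.round(weight / scales), -128, 127).to(torch.int8)
    return q, scales.squeeze(1)                        # scales: (out_features,)

# The forward path then dequantizes on the fly, which is the idea behind the
# linear_int8_aoti / linear_int8_et helpers kept above, e.g.:
#   out = F.linear(x, q.to(x.dtype)) * scales
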
@@ -886,10 +754,10 @@ def quantized_model(self) -> nn.Module: |
886 | 754 | # class references |
887 | 755 | quantizer_class_dict = { |
888 | 756 | "embedding": EmbeddingOnlyQuantHandler, |
889 | | - "linear:int8": WeightOnlyInt8QuantHandler, |
890 | 757 | "precision": PrecisionHandler, |
891 | 758 | "executor": ExecutorHandler, |
892 | 759 | "linear:int4": Int4WeightOnlyQuantizer, |
| 760 | + "linear:int8": int8_weight_only, |
893 | 761 | "linear:a8w4dq": Int8DynActInt4WeightQuantizer, |
894 | 762 | } |
895 | 763 |
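
Note that the table now mixes QuantHandler classes, torchao Quantizer classes, and the bare int8_weight_only config constructor; the linear:int8 entry only serves the membership check, since quantize_model applies quantize_() and continues before reaching the handler path. A simplified sketch of that dispatch (the quantize_options name and the handler argument names are assumed from the handlers in this file):

for quantizer, q_kwargs in quantize_options.items():
    if quantizer not in quantizer_class_dict:
        raise RuntimeError(f"unknown quantizer {quantizer} specified")
    if quantizer == "linear:int8":
        quantize_(model, int8_weight_only())  # torchao-native path, no handler instantiated
        continue
    handler = quantizer_class_dict[quantizer](
        model, device=device, precision=precision, tokenizer=tokenizer, **q_kwargs
    )
    model = handler.quantized_model()
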
@@ -917,6 +785,7 @@ def quantized_model(self) -> nn.Module: |
917 | 785 | IntxWeightEmbeddingQuantizer, |
918 | 786 | ) |
919 | 787 |
|
| 788 | + |
920 | 789 | quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer |
921 | 790 | quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer |
922 | 791 |