PyTorch 2 Export Quantization with Intel GPU Backend through Inductor
======================================================================

**Author**: `Yan Zhiwei <https://github.com/ZhiweiYan-96>`_, `Wang Eikan <https://github.com/EikanWang>`_, `Zhang Liangang <https://github.com/liangan1>`_, `Liu River <https://github.com/riverliuintel>`_, `Cui Yifeng <https://github.com/CuiYifeng>`_

Prerequisites
---------------

- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
- PyTorch 2.7 or later

Introduction
--------------

This tutorial introduces ``XPUInductorQuantizer``, which aims to serve quantized models for inference on Intel GPUs.
``XPUInductorQuantizer`` uses the PyTorch Export Quantization flow and lowers the quantized model into the Inductor backend.

The PyTorch 2 Export Quantization flow uses ``torch.export`` to capture the model into a graph and perform quantization transformations on top of the ATen graph.
This approach is expected to have significantly higher model coverage, better programmability, and a simplified user experience.
TorchInductor is the compiler backend that transforms the FX Graphs generated by ``TorchDynamo`` into optimized C++/Triton kernels.

The quantization flow has three steps:

- Step 1: Capture the FX Graph from the eager model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_.
- Step 2: Apply the quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers,
  performing the prepared model's calibration, and converting the prepared model into the quantized model.
- Step 3: Lower the quantized model into Inductor with the ``torch.compile`` API, which calls Triton kernels or oneDNN GEMM/Convolution kernels.


The high-level architecture of this flow could look like this:

.. image:: ../_static/img/pt2e_quant_xpu_inductor.png
   :align: center

Post Training Quantization
----------------------------

Static quantization is the only method we currently support.

We recommend installing the following dependencies from the Intel GPU channel:

::

    pip3 install torch torchvision torchaudio pytorch-triton-xpu --index-url https://download.pytorch.org/whl/xpu

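
Before moving on, you can optionally verify that the installed build can see your Intel GPU. This is a minimal sanity check, not part of the quantization flow, and it only assumes a PyTorch build with XPU support.

::

    import torch

    # Confirm the installed PyTorch version and that the XPU backend is usable.
    print(torch.__version__)
    print(torch.xpu.is_available())   # expected: True on a working Intel GPU setup
    print(torch.xpu.device_count())   # number of visible Intel GPUs
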

Please note that because the Inductor ``freeze`` feature is not turned on by default yet, you must run your example code with ``TORCHINDUCTOR_FREEZING=1``.

For example:

::

    TORCHINDUCTOR_FREEZING=1 python xpu_inductor_quantizer_example.py

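
If you prefer to configure this from Python instead of through the environment, the flag below is assumed to be the Inductor config counterpart of ``TORCHINDUCTOR_FREEZING``; set it before calling ``torch.compile``.

::

    import torch._inductor.config as inductor_config

    # Assumed Python-side equivalent of TORCHINDUCTOR_FREEZING=1: enables freezing so
    # that weights can be treated as constants and folded/prepacked during compilation.
    inductor_config.freezing = True
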

1. Capture FX Graph
^^^^^^^^^^^^^^^^^^^^^

We will start by performing the necessary imports and capturing the FX Graph from the eager module.

::

    import torch
    import torchvision.models as models
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
    from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
    from torch.export import export_for_training

    # Create the Eager Model
    model_name = "resnet18"
    model = models.__dict__[model_name](weights=models.ResNet18_Weights.DEFAULT)

    # Set the model to eval mode
    model = model.eval().to("xpu")

    # Create the data, using the dummy data here as an example
    traced_bs = 50
    x = torch.randn(traced_bs, 3, 224, 224, device="xpu").contiguous(memory_format=torch.channels_last)
    example_inputs = (x,)

    # Capture the FX Graph to be quantized
    with torch.no_grad():
        exported_model = export_for_training(
            model,
            example_inputs,
        ).module()

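
Optionally, you can print the captured ATen-level graph to see what ``export_for_training`` produced. This is only an inspection aid; it assumes the returned object is a standard ``torch.fx.GraphModule``.

::

    # Optional: inspect the captured graph. The exported module is an FX GraphModule,
    # so the ATen operations it contains can be printed directly.
    print(exported_model.graph)
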

Next, we will quantize the FX Module.

2. Apply Quantization
^^^^^^^^^^^^^^^^^^^^^^^

After we capture the FX Module, we will import the Backend Quantizer for Intel GPU and configure it to
quantize the model.

::

    quantizer = XPUInductorQuantizer()
    quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())

The default quantization configuration in ``XPUInductorQuantizer`` uses signed 8-bit integers for both activations and weights. Activations are asymmetrically quantized per-tensor, whereas weights are symmetrically quantized per-channel.

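
If you want to double-check what the default configuration contains, you can inspect the returned ``QuantizationConfig`` object. The field names below follow the ``QuantizationConfig`` structure constructed later in this tutorial and are shown here only for illustration.

::

    default_config = xpuiq.get_default_xpu_inductor_quantization_config()

    # QuantizationConfig bundles the activation, weight, and bias specs.
    print(default_config.input_activation)  # per-tensor, asymmetric int8 activation spec
    print(default_config.weight)            # per-channel, symmetric int8 weight spec
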
Optionally, in addition to the default configuration, which uses asymmetrically quantized activations, signed 8-bit symmetrically quantized activations are also supported and can potentially deliver better performance.

::

    from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
    from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
    from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
    from typing import Any, TYPE_CHECKING
    if TYPE_CHECKING:
        from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

    def get_xpu_inductor_symm_quantization_config():
        extra_args: dict[str, Any] = {"eps": 2**-12}
        act_observer_or_fake_quant_ctr = HistogramObserver
        act_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_symmetric,  # Change the activation quant config to symmetric
            is_dynamic=False,
            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                **extra_args
            ),
        )

        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
            PerChannelMinMaxObserver
        )

        weight_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_channel_symmetric,  # Same as the default config; the only supported option for weight
            ch_axis=0,  # 0 corresponds to weight shape = (oc, ic, kh, kw) of conv
            is_dynamic=False,
            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                **extra_args
            ),
        )

        bias_quantization_spec = None  # will use placeholder observer by default
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            False,
        )
        return quantization_config

    # Then, set the quantization configuration to the quantizer.
    quantizer = XPUInductorQuantizer()
    quantizer.set_global(get_xpu_inductor_symm_quantization_config())

After the backend-specific quantizer is configured, prepare the model for post-training quantization.
``prepare_pt2e`` folds ``BatchNorm`` operators into the preceding ``Conv2d`` operators and inserts observers at the appropriate places in the model.

::

    prepared_model = prepare_pt2e(exported_model, quantizer)

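
If you want to verify that the ``BatchNorm`` folding took place and that observers were inserted, you can print the prepared module's graph. This is an optional inspection step and only assumes standard ``torch.fx.GraphModule`` printing.

::

    # Optional: the Conv2d/BatchNorm pairs should now appear as fused convolutions,
    # and observer nodes should surround the operators selected for quantization.
    print(prepared_model.graph)
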
**(For static quantization only)** Calibrate the ``prepared_model`` after the observers have been inserted into the model.

::

    # We use the dummy data as an example here
    prepared_model(*example_inputs)

    # Alternatively, users can define their own dataset and calibration loop:
    # def calibrate(model, data_loader):
    #     model.eval()
    #     with torch.no_grad():
    #         for image, target in data_loader:
    #             model(image)
    # calibrate(prepared_model, data_loader_test)  # run calibration on sample data

Finally, convert the calibrated model to a quantized model. ``convert_pt2e`` takes a calibrated model and produces a quantized model.

::

    converted_model = convert_pt2e(prepared_model)

After these steps, the quantization flow has been completed and the quantized model is available.

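
As an optional sanity check before lowering, you can compare the quantized model's output against the original FP32 model on the calibration batch. The comparison below is only a rough illustration of the numerical difference introduced by quantization, not an accuracy evaluation.

::

    with torch.no_grad():
        fp32_out = model(*example_inputs)             # original eager FP32 model
        int8_out = converted_model(*example_inputs)   # quantized reference model

    # Quantization introduces some numerical error, so only the overall magnitude
    # of the difference is of interest here.
    max_abs_err = (fp32_out - int8_out).abs().max().item()
    print(f"max abs difference between FP32 and int8 outputs: {max_abs_err:.4f}")
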

3. Lower into Inductor
^^^^^^^^^^^^^^^^^^^^^^^^

The quantized model will then be lowered into the Inductor backend.

::

    with torch.no_grad():
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)

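
To obtain a rough latency number for the compiled int8 model, you can time a few iterations after a warm-up, synchronizing the XPU device before reading the clock. This is a minimal timing sketch, not a rigorous benchmark.

::

    import time

    with torch.no_grad():
        # Warm-up: the first calls trigger compilation and kernel autotuning.
        for _ in range(3):
            optimized_model(*example_inputs)
        torch.xpu.synchronize()

        # Timed runs
        iters = 20
        start = time.perf_counter()
        for _ in range(iters):
            optimized_model(*example_inputs)
        torch.xpu.synchronize()
        avg_s = (time.perf_counter() - start) / iters
        print(f"average latency: {avg_s * 1000:.2f} ms per batch of {traced_bs}")
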
In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In this case, a convolution or GEMM
operator produces its output in BFloat16 instead of Float32 when there is no subsequent quantization node.
The BFloat16 tensor then propagates seamlessly through the subsequent pointwise operators, reducing memory usage
and potentially improving performance. Using this feature mirrors regular BFloat16 Autocast: simply wrap the
script in the BFloat16 Autocast context.

::

    with torch.amp.autocast(device_type="xpu", dtype=torch.bfloat16), torch.no_grad():
        # Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into the Inductor backend:
        # For operators such as QConvolution and QLinear:
        # * The input data type is consistently defined as int8, attributable to the presence of a pair
        #   of quantization and dequantization nodes inserted at the input.
        # * The computation precision remains at int8.
        # * The output data type may vary, being either int8 or BFloat16, contingent on the presence
        #   of a pair of quantization and dequantization nodes at the output.
        # For non-quantizable pointwise operators, the data type will be inherited from the previous node,
        # potentially resulting in a data type of BFloat16 in this scenario.
        # For quantizable pointwise operators such as QMaxpool2D, they continue to operate with the int8
        # data type for both input and output.
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)


Conclusion
------------

In this tutorial, we have learned how to utilize the ``XPUInductorQuantizer`` to perform post-training quantization on models for inference
on Intel GPUs, leveraging PyTorch 2's Export Quantization flow. We covered the step-by-step process of capturing an FX Graph,
applying quantization, and lowering the quantized model into the Inductor backend using ``torch.compile``. Additionally, we explored
the benefits of using int8-mixed-bf16 quantization for improved memory efficiency and potential performance gains,
especially when using ``BFloat16`` autocast.