-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Intel GPU] Docs of XPUInductorQuantizer #3293
base: main
Are you sure you want to change the base?
Changes from 15 commits
19a568a
6bfa4d7
a8e8d8a
63a63cc
1b3cf01
6a9640a
57985bf
4c4069a
5a6663a
79d56de
6a4f748
e7e2275
d741cf9
9a0f0c7
aae7c51
3a4ecf4
584352c
b31cc28
a84130e
dd328d1
83fc829
4a158d9
c209663
e95b42b
011ef20
98ecedc
e2d9fbf
21a4491
204eb89
a55cb48
046b04f
f92a55b
384b342
6a43ce0
a6516c9
d04de57
8e0293a
980079c
640fa94
17ad15a
41fc5b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,230 @@ | ||||||
PyTorch 2 Export Quantization with Intel GPU Backend through Inductor | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Intel XPU There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
================================================================== | ||||||
|
||||||
**Author**: `Yan Zhiwei <https://github.com/ZhiweiYan-96>`_, `Wang Eikan <https://github.com/EikanWang>`_, `Zhang Liangang <https://github.com/liangan1>`_, `Liu River <https://github.com/riverliuintel>`_, `Cui Yifeng <https://github.com/CuiYifeng>`_ | ||||||
|
||||||
Prerequisites | ||||||
--------------- | ||||||
|
||||||
- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_ | ||||||
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_ | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
Introduction | ||||||
-------------- | ||||||
|
||||||
This tutorial introduces XPUInductorQuantizer aiming for serving the quantized model inference on Intel GPUs. The tutorial will cover how it | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
utilizes PyTorch 2 Export Quantization flow and lowers the quantized model into the inductor. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are you trying to say in this phrase: "lowers the quantized model into the inductor"? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the terminology in torch.compile |
||||||
|
||||||
The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and perform quantization transformations on top of the ATen graph. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
This approach is expected to have significantly higher model coverage with better programmability and a simplified user experience. | ||||||
TorchInductor is the compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
The quantization flow mainly includes three steps: | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
- Step 1: Capture the FX Graph from the eager Model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
- Step 2: Apply the quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers, | ||||||
performing the prepared model's calibration, and converting the prepared model into the quantized model. | ||||||
- Step 3: Lower the quantized model into inductor with the API ``torch.compile``, which would call triton kernels or oneDNN GEMM/Conv kernels. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does it mean to "Lower the quantized model into inductor"? |
||||||
|
||||||
|
||||||
The high-level architecture of this flow could look like this: | ||||||
|
||||||
.. image:: ../_static/img/pt2e_quant_xpu_inductor.png | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please note that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for reminding, the pictures is moidified |
||||||
:align: center | ||||||
|
||||||
Post Training Quantization | ||||||
---------------------------- | ||||||
|
||||||
Static quantization is the only method we support currently. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
The dependencies packages are recommended to be installed through Intel GPU channel as follows | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:: | ||||||
|
||||||
pip3 install torch torchvision torchaudio pytorch-triton-xpu --index-url https://download.pytorch.org/whl/xpu | ||||||
|
||||||
1. Capture FX Graph | ||||||
^^^^^^^^^^^^^^^^^^^^^ | ||||||
|
||||||
We will start by performing the necessary imports, capturing the FX Graph from the eager module. | ||||||
|
||||||
:: | ||||||
|
||||||
import torch | ||||||
import torchvision.models as models | ||||||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e | ||||||
import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq | ||||||
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer | ||||||
from torch.export import export_for_training | ||||||
|
||||||
# Create the Eager Model | ||||||
model_name = "resnet18" | ||||||
model = models.__dict__[model_name](weights=models.ResNet18_Weights.DEFAULT) | ||||||
|
||||||
# Set the model to eval mode | ||||||
model = model.eval().to("xpu") | ||||||
|
||||||
# Create the data, using the dummy data here as an example | ||||||
traced_bs = 50 | ||||||
x = torch.randn(traced_bs, 3, 224, 224, device="xpu").contiguous(memory_format=torch.channels_last) | ||||||
example_inputs = (x,) | ||||||
|
||||||
# Capture the FX Graph to be quantized | ||||||
with torch.no_grad(): | ||||||
exported_model = export_for_training( | ||||||
model, | ||||||
example_inputs, | ||||||
).module() | ||||||
|
||||||
|
||||||
Next, we will have the FX Module to be quantized. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
2. Apply Quantization | ||||||
^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|
||||||
After we capture the FX Module to be quantized, we will import the Backend Quantizer for Intel GPU and configure how to | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
quantize the model. | ||||||
|
||||||
:: | ||||||
|
||||||
quantizer = XPUInductorQuantizer() | ||||||
quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config()) | ||||||
|
||||||
The default quantization configuration in ``XPUInductorQuantizer`` uses signed 8-bits for both activations and weights. The tensor is per-tensor quantized, while the weight is signed 8-bit per-channel quantized. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
Besides the default quant configuration (asymmetric quantized activation), we also support signed 8-bits symmetric quantized activation, which has the potential to provide better performance. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
:: | ||||||
|
||||||
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver | ||||||
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec | ||||||
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig | ||||||
from typing import Any, Optional, TYPE_CHECKING | ||||||
if TYPE_CHECKING: | ||||||
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor | ||||||
def get_xpu_inductor_symm_quantization_config(): | ||||||
extra_args: dict[str, Any] = {"eps": 2**-12} | ||||||
act_observer_or_fake_quant_ctr = HistogramObserver | ||||||
act_quantization_spec = QuantizationSpec( | ||||||
dtype=torch.int8, | ||||||
quant_min=-128, | ||||||
quant_max=127, | ||||||
qscheme=torch.per_tensor_symmetric, # Change the activation quant config to symmetric | ||||||
is_dynamic=False, | ||||||
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( | ||||||
**extra_args | ||||||
), | ||||||
) | ||||||
|
||||||
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( | ||||||
PerChannelMinMaxObserver | ||||||
) | ||||||
|
||||||
weight_quantization_spec = QuantizationSpec( | ||||||
dtype=torch.int8, | ||||||
quant_min=-128, | ||||||
quant_max=127, | ||||||
qscheme=torch.per_channel_symmetric, # Same as the default config, the only supported option for weight | ||||||
ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv | ||||||
is_dynamic=False, | ||||||
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( | ||||||
**extra_args | ||||||
), | ||||||
) | ||||||
|
||||||
bias_quantization_spec = None # will use placeholder observer by default | ||||||
quantization_config = QuantizationConfig( | ||||||
act_quantization_spec, | ||||||
act_quantization_spec, | ||||||
weight_quantization_spec, | ||||||
bias_quantization_spec, | ||||||
False, | ||||||
) | ||||||
return quantization_config | ||||||
|
||||||
Then, we can set the quantization configuration to the quantizer. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:: | ||||||
|
||||||
quantizer = XPUInductorQuantizer() | ||||||
quantizer.set_global(get_xpu_inductor_symm_quantization_config()) | ||||||
|
||||||
After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:: | ||||||
|
||||||
prepared_model = prepare_pt2e(exported_model, quantizer) | ||||||
|
||||||
Now, we will calibrate the ``prepared_model`` after the observers are inserted in the model. This step is needed for static quantization only. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:: | ||||||
|
||||||
# We use the dummy data as an example here | ||||||
prepared_model(*example_inputs) | ||||||
|
||||||
# Alternatively: user can define the dataset to calibrate | ||||||
# def calibrate(model, data_loader): | ||||||
# model.eval() | ||||||
# with torch.no_grad(): | ||||||
# for image, target in data_loader: | ||||||
# model(image) | ||||||
# calibrate(prepared_model, data_loader_test) # run calibration on sample data | ||||||
|
||||||
Finally, we will convert the calibrated Model to a quantized Model. ``convert_pt2e`` takes a calibrated model and produces a quantized model. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:: | ||||||
|
||||||
converted_model = convert_pt2e(prepared_model) | ||||||
|
||||||
After these steps, we finished running the quantization flow and we will get the quantized model. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
|
||||||
3. Lower into Inductor | ||||||
^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|
||||||
After we get the quantized model, we will further lower it to the inductor backend. | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
:: | ||||||
|
||||||
with torch.no_grad(): | ||||||
optimized_model = torch.compile(converted_model) | ||||||
|
||||||
# Running some benchmark | ||||||
optimized_model(*example_inputs) | ||||||
|
||||||
In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In this instance, | ||||||
a Convolution or GEMM operator produces BFloat16 output data type instead of Float32 in the absence | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
of a subsequent quantization node. Subsequently, the BFloat16 tensor seamlessly propagates through | ||||||
subsequent pointwise operators, effectively minimizing memory usage and potentially enhancing performance. | ||||||
The utilization of this feature mirrors that of regular BFloat16 Autocast, as simple as wrapping the | ||||||
script within the BFloat16 Autocast context. | ||||||
|
||||||
:: | ||||||
|
||||||
with torch.amp.autocast(device_type="xpu", dtype=torch.bfloat16), torch.no_grad(): | ||||||
# Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into indcutor backend, | ||||||
# For operators such as QConvolution and QLinear: | ||||||
# * The input data type is consistently defined as int8, attributable to the presence of a pair | ||||||
# of quantization and dequantization nodes inserted at the input. | ||||||
# * The computation precision remains at int8. | ||||||
# * The output data type may vary, being either int8 or BFloat16, contingent on the presence | ||||||
# of a pair of quantization and dequantization nodes at the output. | ||||||
# For non-quantizable pointwise operators, the data type will be inherited from the previous node, | ||||||
# potentially resulting in a data type of BFloat16 in this scenario. | ||||||
# For quantizable pointwise operators such as QMaxpool2D, it continues to operate with the int8 | ||||||
# data type for both input and output. | ||||||
optimized_model = torch.compile(converted_model) | ||||||
|
||||||
# Running some benchmark | ||||||
optimized_model(*example_inputs) | ||||||
ZhiweiYan-96 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
|
||||||
Put all these codes together, we will have the toy example code. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to have this sentence, just delete it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for advice, removed. |
||||||
Please note that since the Inductor ``freeze`` feature does not turn on by default yet, run your example code with ``TORCHINDUCTOR_FREEZING=1``. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps it might be better to put this near the top so developers are aware they need to run with this env variable set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for suggestions, I moved to the top of example code section. |
||||||
|
||||||
For example: | ||||||
|
||||||
:: | ||||||
|
||||||
TORCHINDUCTOR_FREEZING=1 python xpu_inductor_quantizer_example.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Intel XPU
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At previous stage when we upload RFCs, we recommend using GPU instead of XPU for readability for users. Do we have some changes on this description desicsion?