
Commit 459084a

ZhiweiYan-96, xiaolil1, svekars, and alexsin368 authored
[Intel GPU] Docs of XPUInductorQuantizer (#3293)
* [Intel GPU] Docs of XPUInductorQuantizer

---------

Co-authored-by: xiaolil1 <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: alexsin368 <[email protected]>
1 parent aebeff4 commit 459084a

File tree

3 files changed: +245 -0 lines changed


Diff for: _static/img/pt2e_quant_xpu_inductor.png

117 KB

Diff for: prototype_source/prototype_index.rst

+7
@@ -96,6 +96,13 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/pt2e_quant_x86_inductor.html
    :tags: Quantization
 
+.. customcarditem::
+   :header: PyTorch 2 Export Quantization with Intel GPU Backend through Inductor
+   :card_description: Learn how to use PT2 Export Quantization with Intel GPU Backend through Inductor.
+   :image: _static/img/thumbnails/cropped/pytorch-logo.png
+   :link: ../prototype/pt2e_quant_xpu_inductor.html
+   :tags: Quantization
+
 .. Sparsity
 
 .. customcarditem::

Diff for: prototype_source/pt2e_quant_xpu_inductor.rst

+238
@@ -0,0 +1,238 @@
PyTorch 2 Export Quantization with Intel GPU Backend through Inductor
======================================================================

**Author**: `Yan Zhiwei <https://github.com/ZhiweiYan-96>`_, `Wang Eikan <https://github.com/EikanWang>`_, `Zhang Liangang <https://github.com/liangan1>`_, `Liu River <https://github.com/riverliuintel>`_, `Cui Yifeng <https://github.com/CuiYifeng>`_

Prerequisites
---------------

- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
- PyTorch 2.7 or later (a quick environment check follows this list)
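
As a quick sanity check before proceeding, you can confirm that the installed PyTorch build exposes the XPU backend. The snippet below is a minimal sketch assuming a standard XPU-enabled PyTorch 2.7+ installation:

::

    import torch

    print(torch.__version__)                 # expect 2.7 or later
    print(torch.xpu.is_available())          # should be True on a working Intel GPU setup
    if torch.xpu.is_available():
        print(torch.xpu.get_device_name(0))  # name of the detected Intel GPU
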
Introduction
--------------

This tutorial introduces ``XPUInductorQuantizer``, which aims to serve quantized models for inference on Intel GPUs.
``XPUInductorQuantizer`` uses the PyTorch 2 Export Quantization flow and lowers the quantized model into Inductor.

The PyTorch 2 Export Quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph.
This approach is expected to have significantly higher model coverage, better programmability, and a simplified user experience.
TorchInductor is a compiler backend that transforms the FX graphs generated by ``TorchDynamo`` into optimized C++/Triton kernels.

The quantization flow has three steps:

- Step 1: Capture the FX Graph from the eager model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_.
- Step 2: Apply the quantization flow to the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers,
  calibrating the prepared model, and converting the prepared model into the quantized model.
- Step 3: Lower the quantized model into Inductor with ``torch.compile``, which calls Triton kernels or oneDNN GEMM/Convolution kernels.

The high-level architecture of this flow is shown below; a condensed sketch of the three steps follows the figure.

.. image:: ../_static/img/pt2e_quant_xpu_inductor.png
   :align: center

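
As a preview, the following condensed sketch strings the three steps together on a small toy module (a hypothetical two-layer network, not the ResNet-18 example used later). Each step is explained in detail in the sections below:

::

    import torch
    import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
    from torch.export import export_for_training

    # Toy model and dummy calibration data, for illustration only
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval().to("xpu")
    example_inputs = (torch.randn(1, 3, 32, 32, device="xpu"),)

    with torch.no_grad():
        # Step 1: capture the FX Graph
        exported = export_for_training(model, example_inputs).module()

        # Step 2: prepare, calibrate, and convert
        quantizer = XPUInductorQuantizer()
        quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())
        prepared = prepare_pt2e(exported, quantizer)
        prepared(*example_inputs)          # calibration with sample data
        quantized = convert_pt2e(prepared)

        # Step 3: lower into Inductor (run with TORCHINDUCTOR_FREEZING=1, see below)
        optimized = torch.compile(quantized)
        optimized(*example_inputs)
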
Post Training Quantization
----------------------------

Static quantization is the only method we currently support.

We recommend installing the following dependencies through the Intel GPU channel:

::

    pip3 install torch torchvision torchaudio pytorch-triton-xpu --index-url https://download.pytorch.org/whl/xpu

Please note that because the Inductor ``freeze`` feature is not turned on by default yet, you must run your example code with ``TORCHINDUCTOR_FREEZING=1``.
For example:

::

    TORCHINDUCTOR_FREEZING=1 python xpu_inductor_quantizer_example.py

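
If setting an environment variable is inconvenient, the same behavior can usually be enabled programmatically through the internal ``torch._inductor.config.freezing`` flag. This is an internal setting that may change between releases, so treat the following as a sketch rather than a supported API:

::

    import torch._inductor.config as inductor_config

    # Equivalent to running with TORCHINDUCTOR_FREEZING=1 (internal flag, may change)
    inductor_config.freezing = True
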
1. Capture FX Graph
^^^^^^^^^^^^^^^^^^^^^

We will start by performing the necessary imports and capturing the FX Graph from the eager module.

::

    import torch
    import torchvision.models as models
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
    from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
    from torch.export import export_for_training

    # Create the eager model
    model_name = "resnet18"
    model = models.__dict__[model_name](weights=models.ResNet18_Weights.DEFAULT)

    # Set the model to eval mode
    model = model.eval().to("xpu")

    # Create the data, using dummy data here as an example
    traced_bs = 50
    x = torch.randn(traced_bs, 3, 224, 224, device="xpu").contiguous(memory_format=torch.channels_last)
    example_inputs = (x,)

    # Capture the FX Graph to be quantized
    with torch.no_grad():
        exported_model = export_for_training(
            model,
            example_inputs,
        ).module()

Next, we will quantize the FX module.

2. Apply Quantization
^^^^^^^^^^^^^^^^^^^^^^^

After we capture the FX module, we will import the backend quantizer for Intel GPU and configure it to
quantize the model.

::

    quantizer = XPUInductorQuantizer()
    quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())

The default quantization configuration in ``XPUInductorQuantizer`` uses signed 8 bits for both activations and weights. The activations are per-tensor quantized, whereas the weights are signed 8-bit per-channel quantized.

Optionally, in addition to the default configuration, which uses asymmetrically quantized activations, signed 8-bit symmetrically quantized activations are also supported and have the potential to provide better performance.

::

    from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
    from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
    from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
    from typing import Any, Optional, TYPE_CHECKING

    if TYPE_CHECKING:
        from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

    def get_xpu_inductor_symm_quantization_config():
        extra_args: dict[str, Any] = {"eps": 2**-12}
        act_observer_or_fake_quant_ctr = HistogramObserver
        act_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_symmetric,  # Change the activation quant config to symmetric
            is_dynamic=False,
            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                **extra_args
            ),
        )

        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
            PerChannelMinMaxObserver
        )

        weight_quantization_spec = QuantizationSpec(
            dtype=torch.int8,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_channel_symmetric,  # Same as the default config, the only supported option for weight
            ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
            is_dynamic=False,
            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                **extra_args
            ),
        )

        bias_quantization_spec = None  # will use placeholder observer by default
        quantization_config = QuantizationConfig(
            act_quantization_spec,
            act_quantization_spec,
            weight_quantization_spec,
            bias_quantization_spec,
            False,
        )
        return quantization_config

    # Then, set the quantization configuration to the quantizer.
    quantizer = XPUInductorQuantizer()
    quantizer.set_global(get_xpu_inductor_symm_quantization_config())

After the backend-specific quantizer is configured, prepare the model for post-training quantization.
``prepare_pt2e`` folds ``BatchNorm`` operators into preceding ``Conv2d`` operators and inserts observers at appropriate places in the model.

::

    prepared_model = prepare_pt2e(exported_model, quantizer)

**(For static quantization only)** Calibrate the ``prepared_model`` after the observers have been inserted into the model.

::

    # We use dummy data as an example here
    prepared_model(*example_inputs)

    # Alternatively: users can define their own dataset to calibrate
    # def calibrate(model, data_loader):
    #     model.eval()
    #     with torch.no_grad():
    #         for image, target in data_loader:
    #             model(image)
    # calibrate(prepared_model, data_loader_test)  # run calibration on sample data

Finally, convert the calibrated model into a quantized model. ``convert_pt2e`` takes a calibrated model and produces a quantized model.

::

    converted_model = convert_pt2e(prepared_model)

After these steps, the quantization flow is complete and the quantized model is available.

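
Because ``convert_pt2e`` returns a ``torch.fx.GraphModule``, you can optionally inspect the inserted quantize/dequantize operations before lowering. A minimal sketch:

::

    # Print the FX graph, including the quantize/dequantize ops inserted by convert_pt2e
    print(converted_model.graph)

    # Or render the generated forward() code in a Python-like form
    converted_model.print_readable()
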
3. Lower into Inductor
^^^^^^^^^^^^^^^^^^^^^^^^

The quantized model is then lowered into the Inductor backend.

::

    with torch.no_grad():
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)

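
The single call above only exercises the compiled model once. For a rough latency measurement, one option is to time repeated runs with explicit device synchronization; the warm-up and iteration counts below are arbitrary choices for illustration:

::

    import time

    with torch.no_grad():
        # Warm up to exclude compilation and autotuning time
        for _ in range(10):
            optimized_model(*example_inputs)
        torch.xpu.synchronize()

        start = time.perf_counter()
        for _ in range(100):
            optimized_model(*example_inputs)
        torch.xpu.synchronize()
        elapsed = time.perf_counter() - start
        print(f"avg latency: {elapsed / 100 * 1000:.2f} ms")
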
In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In this case,
a convolution or GEMM operator produces its output in BFloat16 instead of Float32 when there is no
subsequent quantization node. The BFloat16 tensor then propagates seamlessly through
subsequent pointwise operators, reducing memory usage and potentially improving performance.
Using this feature mirrors regular BFloat16 autocast: simply wrap the
script in the BFloat16 autocast context.

::

    with torch.amp.autocast(device_type="xpu", dtype=torch.bfloat16), torch.no_grad():
        # Turn on autocast to use int8-mixed-bf16 quantization. After lowering into the Inductor backend,
        # for operators such as QConvolution and QLinear:
        # * The input data type is consistently defined as int8, attributable to the presence of a pair
        #   of quantization and dequantization nodes inserted at the input.
        # * The computation precision remains at int8.
        # * The output data type may vary, being either int8 or BFloat16, contingent on the presence
        #   of a pair of quantization and dequantization nodes at the output.
        # For non-quantizable pointwise operators, the data type will be inherited from the previous node,
        # potentially resulting in a data type of BFloat16 in this scenario.
        # For quantizable pointwise operators such as QMaxpool2D, they continue to operate with the int8
        # data type for both input and output.
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)

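
Because quantization trades some numerical accuracy for performance, it can be useful to compare the quantized model's output with that of the original FP32 model on a few inputs. The following sketch uses cosine similarity on the dummy batch from earlier; the metric and data choice are illustrative only:

::

    import torch.nn.functional as F

    with torch.no_grad():
        ref = model(*example_inputs)            # original FP32 eager model on "xpu"
        out = optimized_model(*example_inputs)  # quantized, compiled model

        cos = F.cosine_similarity(ref.float().flatten(), out.float().flatten(), dim=0)
        print(f"cosine similarity vs. FP32 reference: {cos.item():.4f}")
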
Conclusion
------------

In this tutorial, we learned how to use ``XPUInductorQuantizer`` to perform post-training quantization on models for inference
on Intel GPUs, leveraging the PyTorch 2 Export Quantization flow. We covered the step-by-step process of capturing an FX Graph,
applying quantization, and lowering the quantized model into the Inductor backend using ``torch.compile``. Additionally, we explored
the benefits of int8-mixed-bf16 quantization for improved memory efficiency and potential performance gains,
especially when using ``BFloat16`` autocast.
