-
Notifications
You must be signed in to change notification settings - Fork 4.1k
PyTorch 2 Export Quantization for OpenVINO torch.compile backend #3321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f0ab805
acf1647
5b1c99a
b2eaa82
810899a
82a47a5
26f044b
75d3549
e8e94d3
f09a85f
2c766e7
b424f92
f3137be
b7d2781
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
PyTorch 2 Export Quantization for OpenVINO torch.compile Backend. | ||
=========================================================================== | ||
|
||
**Authors**: `Daniil Lyakhov <https://github.com/daniil-lyakhov>`_, `Aamir Nazir <https://github.com/anzr299>`_, `Alexander Suslov <https://github.com/alexsu52>`_, `Yamini Nimmagadda <https://github.com/ynimmaga>`_, `Alexander Kozlov <https://github.com/AlexKoff88>`_ | ||
|
||
Prerequisites | ||
-------------- | ||
- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_ | ||
- `How to Write a Quantizer for PyTorch 2 Export Quantization <https://pytorch.org/tutorials/prototype/pt2e_quantizer.html>`_ | ||
|
||
Introduction | ||
-------------- | ||
|
||
.. note:: | ||
|
||
This is an experimental feature, the quantization API is subject to change. | ||
|
||
This tutorial demonstrates how to use ``OpenVINOQuantizer`` from `Neural Network Compression Framework (NNCF) <https://github.com/openvinotoolkit/nncf/tree/develop>`_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`_ and explains how to lower the quantized model into the `OpenVINO <https://docs.openvino.ai/2024/index.html>`_ representation. | ||
``OpenVINOQuantizer`` unlocks the full potential of low-precision OpenVINO kernels due to the placement of quantizers designed specifically for the OpenVINO. | ||
|
||
The PyTorch 2 export quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph. | ||
This approach is expected to have significantly higher model coverage, improved flexibility, and a simplified UX. | ||
OpenVINO backend compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. | ||
|
||
The quantization flow mainly includes four steps: | ||
|
||
- Step 1: Capture the FX Graph from the eager Model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_. | ||
- Step 2: Apply the PyTorch 2 Export Quantization flow with OpenVINOQuantizer based on the captured FX Graph. | ||
- Step 3: Lower the quantized model into OpenVINO representation with the `torch.compile <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`_ API. | ||
- Optional step 4: : Improve quantized model metrics via `quantize_pt2e <https://openvinotoolkit.github.io/nncf/autoapi/nncf/experimental/torch/fx/index.html#nncf.experimental.torch.fx.quantize_pt2e>`_ method. | ||
|
||
The high-level architecture of this flow could look like this: | ||
|
||
:: | ||
|
||
float_model(Python) Example Input | ||
\ / | ||
\ / | ||
—-------------------------------------------------------- | ||
| export | | ||
—-------------------------------------------------------- | ||
| | ||
FX Graph in ATen | ||
| | ||
| OpenVINOQuantizer | ||
| / | ||
—-------------------------------------------------------- | ||
| prepare_pt2e | | ||
| | | | ||
| Calibrate | ||
| | | | ||
| convert_pt2e | | ||
—-------------------------------------------------------- | ||
| | ||
Quantized Model | ||
| | ||
—-------------------------------------------------------- | ||
| Lower into Inductor | | ||
—-------------------------------------------------------- | ||
| | ||
OpenVINO model | ||
|
||
Post Training Quantization | ||
---------------------------- | ||
|
||
Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_ | ||
for post training quantization. | ||
|
||
Prerequisite: OpenVINO and NNCF installation | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
OpenVINO and NNCF could be easily installed via `pip distribution <https://docs.openvino.ai/2024/get-started/install-openvino.html>`_: | ||
|
||
.. code-block:: bash | ||
|
||
pip install -U pip | ||
pip install openvino, nncf | ||
|
||
|
||
1. Capture FX Graph | ||
^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
We will start by performing the necessary imports, capturing the FX Graph from the eager module. | ||
|
||
.. code-block:: python | ||
|
||
import copy | ||
import openvino.torch | ||
import torch | ||
import torchvision.models as models | ||
from torch.ao.quantization.quantize_pt2e import convert_pt2e | ||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e | ||
Comment on lines
+90
to
+91
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for long term support, importing from this requires people to install torchao nightly though: but this can be done in a separate step since you might need to adapt your code to work with the torchao copy |
||
|
||
import nncf.torch | ||
|
||
# Create the Eager Model | ||
model_name = "resnet18" | ||
model = models.__dict__[model_name](pretrained=True) | ||
|
||
# Set the model to eval mode | ||
model = model.eval() | ||
|
||
# Create the data, using the dummy data here as an example | ||
traced_bs = 50 | ||
x = torch.randn(traced_bs, 3, 224, 224) | ||
example_inputs = (x,) | ||
|
||
# Capture the FX Graph to be quantized | ||
with torch.no_grad(), nncf.torch.disable_patching(): | ||
exported_model = torch.export.export(model, example_inputs).module() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Edit: looks like in new pytorch versions these two are the same, in that case we might be recommending There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK I checked with Tugsuu, please continue to use |
||
|
||
|
||
|
||
2. Apply Quantization | ||
^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
After we capture the FX Module to be quantized, we will import the OpenVINOQuantizer. | ||
|
||
|
||
.. code-block:: python | ||
|
||
from nncf.experimental.torch.fx import OpenVINOQuantizer | ||
|
||
quantizer = OpenVINOQuantizer() | ||
|
||
``OpenVINOQuantizer`` has several optional parameters that allow tuning the quantization process to get a more accurate model. | ||
Below is the list of essential parameters and their description: | ||
|
||
|
||
* ``preset`` - defines quantization scheme for the model. Two types of presets are available: | ||
|
||
* ``PERFORMANCE`` (default) - defines symmetric quantization of weights and activations | ||
|
||
* ``MIXED`` - weights are quantized with symmetric quantization and the activations are quantized with asymmetric quantization. This preset is recommended for models with non-ReLU and asymmetric activation functions, e.g. ELU, PReLU, GELU, etc. | ||
|
||
.. code-block:: python | ||
|
||
OpenVINOQuantizer(preset=nncf.QuantizationPreset.MIXED) | ||
|
||
* ``model_type`` - used to specify quantization scheme required for specific type of the model. Transformer is the only supported special quantization scheme to preserve accuracy after quantization of Transformer models (BERT, Llama, etc.). None is default, i.e. no specific scheme is defined. | ||
|
||
.. code-block:: python | ||
|
||
OpenVINOQuantizer(model_type=nncf.ModelType.Transformer) | ||
|
||
* ``ignored_scope`` - this parameter can be used to exclude some layers from the quantization process to preserve the model accuracy. For example, when you want to exclude the last layer of the model from quantization. Below are some examples of how to use this parameter: | ||
|
||
.. code-block:: python | ||
|
||
#Exclude by layer name: | ||
names = ['layer_1', 'layer_2', 'layer_3'] | ||
OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(names=names)) | ||
|
||
#Exclude by layer type: | ||
types = ['Conv2d', 'Linear'] | ||
OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(types=types)) | ||
|
||
#Exclude by regular expression: | ||
regex = '.*layer_.*' | ||
OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(patterns=regex)) | ||
|
||
#Exclude by subgraphs: | ||
# In this case, all nodes along all simple paths in the graph | ||
# from input to output nodes will be excluded from the quantization process. | ||
subgraph = nncf.Subgraph(inputs=['layer_1', 'layer_2'], outputs=['layer_3']) | ||
OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(subgraphs=[subgraph])) | ||
|
||
|
||
* ``target_device`` - defines the target device, the specificity of which will be taken into account during optimization. The following values are supported: ``ANY`` (default), ``CPU``, ``CPU_SPR``, ``GPU``, and ``NPU``. | ||
|
||
.. code-block:: python | ||
|
||
OpenVINOQuantizer(target_device=nncf.TargetDevice.CPU) | ||
|
||
For further details on `OpenVINOQuantizer` please see the `documentation <https://openvinotoolkit.github.io/nncf/autoapi/nncf/experimental/torch/fx/index.html#nncf.experimental.torch.fx.OpenVINOQuantizer>`_. | ||
|
||
After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. | ||
``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. | ||
|
||
.. code-block:: python | ||
|
||
prepared_model = prepare_pt2e(exported_model, quantizer) | ||
|
||
Now, we will calibrate the ``prepared_model`` after the observers are inserted in the model. | ||
|
||
.. code-block:: python | ||
|
||
# We use the dummy data as an example here | ||
prepared_model(*example_inputs) | ||
|
||
Finally, we will convert the calibrated Model to a quantized Model. ``convert_pt2e`` takes a calibrated model and produces a quantized model. | ||
|
||
.. code-block:: python | ||
|
||
quantized_model = convert_pt2e(prepared_model, fold_quantize=False) | ||
|
||
After these steps, we finished running the quantization flow, and we will get the quantized model. | ||
|
||
|
||
3. Lower into OpenVINO representation | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
After that the FX Graph can utilize OpenVINO optimizations using `torch.compile(…, backend=”openvino”) <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`_ functionality. | ||
|
||
.. code-block:: python | ||
|
||
with torch.no_grad(), nncf.torch.disable_patching(): | ||
optimized_model = torch.compile(quantized_model, backend="openvino") | ||
|
||
# Running some benchmark | ||
optimized_model(*example_inputs) | ||
|
||
|
||
|
||
The optimized model is using low-level kernels designed specifically for Intel CPU. | ||
This should significantly speed up inference time in comparison with the eager model. | ||
|
||
4. Optional: Improve quantized model metrics | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
NNCF implements advanced quantization algorithms like `SmoothQuant <https://arxiv.org/abs/2211.10438>`_ and `BiasCorrection <https://arxiv.org/abs/1906.04721>`_, which help | ||
to improve the quantized model metrics while minimizing the output discrepancies between the original and compressed models. | ||
These advanced NNCF algorithms can be accessed via the NNCF `quantize_pt2e` API: | ||
|
||
.. code-block:: python | ||
|
||
from nncf.experimental.torch.fx import quantize_pt2e | ||
|
||
calibration_loader = torch.utils.data.DataLoader(...) | ||
|
||
|
||
def transform_fn(data_item): | ||
images, _ = data_item | ||
return images | ||
|
||
|
||
calibration_dataset = nncf.Dataset(calibration_loader, transform_fn) | ||
quantized_model = quantize_pt2e( | ||
exported_model, quantizer, calibration_dataset, smooth_quant=True, fast_bias_correction=False | ||
) | ||
|
||
|
||
For further details, please see the `documentation <https://openvinotoolkit.github.io/nncf/autoapi/nncf/experimental/torch/fx/index.html#nncf.experimental.torch.fx.quantize_pt2e>`_ | ||
and a complete `example on Resnet18 quantization <https://github.com/openvinotoolkit/nncf/blob/develop/examples/post_training_quantization/torch_fx/resnet18/README.md>`_. | ||
|
||
Conclusion | ||
------------ | ||
|
||
This tutorial introduces how to use torch.compile with the OpenVINO backend and the OpenVINO quantizer. | ||
For more details on NNCF and the NNCF Quantization Flow for PyTorch models, refer to the `NNCF Quantization Guide <https://docs.openvino.ai/2025/openvino-workflow/model-optimization-guide/quantizing-models-post-training/basic-quantization-flow.html.>`_. | ||
For additional information, check out the `OpenVINO Deployment via torch.compile Documentation <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`_. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add this file to prototype_source/prototype_index.rst toctree and add a carditem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, done