From f0ab80548b462668c5451dbb917004f419360b30 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 21 Nov 2024 18:39:12 +0100 Subject: [PATCH 01/11] WIP --- .../openvino_nncf_quantization.rst | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 prototype_source/openvino_nncf_quantization.rst diff --git a/prototype_source/openvino_nncf_quantization.rst b/prototype_source/openvino_nncf_quantization.rst new file mode 100644 index 00000000000..94c6968eaf5 --- /dev/null +++ b/prototype_source/openvino_nncf_quantization.rst @@ -0,0 +1,221 @@ +PyTorch 2 Export Quantization with NNCF quantization and OpenVINO runtime +=========================================================================== + +**Author**: dlyakhov, asuslov, aamir, # TODO: add required authors + +Introduction +-------------- + +This tutorial introduces the steps for utilizing the `Neural Network Compression Framework (nncf) `_ to generate a quantized model customized +for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. + +The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and perform quantization transformations on top of the ATen graph. +This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. +OpenVINO is the new backend that compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. + +The quantization flow mainly includes three steps: + +- Step 1: OpenVINO and NNCF installation. +- Step 2: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. +- Step 3: Apply the Quantization flow based on the captured FX Graph. +- Step 4: Lower the quantized model into OpenVINO representation with the API ``torch.compile``. + +The high-level architecture of this flow could look like this: + +:: + + float_model(Python) Example Input + \ / + \ / + —-------------------------------------------------------- + | export | + —-------------------------------------------------------- + | + FX Graph in ATen + | + | + —-------------------------------------------------------- + | nncf.quantize | + —-------------------------------------------------------- + | + Quantized Model + | + —-------------------------------------------------------- + | Lower into Inductor | + —-------------------------------------------------------- + | + OpenVINO model + +Post Training Quantization +---------------------------- + +Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model `_ +for post training quantization. + +1. OpenVINO and NNCF installation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +OpenVINO and NNCF could be easily installed via `pip distribution `_: + +.. code-block:: bash + + pip install -U pip + pip install openvino, nncf + + +2. Capture FX Graph +^^^^^^^^^^^^^^^^^^^^^ + +We will start by performing the necessary imports, capturing the FX Graph from the eager module. + +.. 
code-block:: python + + import torch + import torchvision.models as models + import copy + import openvino.torch + + import nncf + from nncf.torch import disable_patching + + # Create the Eager Model + model_name = "resnet18" + model = models.__dict__[model_name](pretrained=True) + + # Set the model to eval mode + model = model.eval() + + # Create the data, using the dummy data here as an example + traced_bs = 50 + x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last) + example_inputs = (x,) + + # Capture the FX Graph to be quantized + with torch.no_grad(): + with disable_patching(): + exported_model = torch.export.export(model, example_inputs).module() + + +Next, we will have the FX Module to be quantized. + +3. Apply Quantization +^^^^^^^^^^^^^^^^^^^^^^^ + +Before the quantization, we need to create an instance of the nncf.Dataset class that represents the calibration dataset. +The ``nncf.Dataset`` class can be a wrapper over the framework dataset object that is used for model training or validation +The class constructor receives the dataset object and an optional transformation function. + +The transformation function is a function that takes a sample from the dataset and returns data that can be passed to the model for inference. +For example, this function can take a tuple of a data tensor and labels tensor and return the former while ignoring the latter. +The transformation function is used to avoid modifying the dataset code to make it compatible with the quantization API. +The function is applied to each sample from the dataset before passing it to the model for inference. +The following code snippet shows how to create an instance of the ``nncf.Dataset`` class: + +.. code-block:: python + + calibration_loader = torch.utils.data.DataLoader([example_inputs]) + + def transform_fn(data_item): + # In the transformation function, + # user can separate labels and input data + # from the given data item: + # images, _ = data_item + return data_item + + calibration_dataset = nncf.Dataset(calibration_loader, transform_fn) + +If there is no framework dataset object, you can create your own entity that implements the Iterable interface in Python, +for example, the list of images, and returns data samples feasible for inference. In this case, a transformation function is not required. + +Once the dataset is ready and the model object is instantiated, you can apply 8-bit quantization to it. + +.. code-block:: python + + with disable_patching(): + quantized_model = nncf.quantize(exported_model, calibration_dataset) + +``nncf.quantize()`` function has several optional parameters that allow tuning the quantization process to get a more accurate model. +Below is the list of parameters and their description: + +* ``model_type`` - used to specify quantization scheme required for specific type of the model. Transformer is the only supported special quantization scheme to preserve accuracy after quantization of Transformer models (BERT, DistilBERT, etc.). None is default, i.e. no specific scheme is defined. +.. code-block:: python + + nncf.quantize(model, dataset, model_type=nncf.ModelType.Transformer) + +* ``preset`` - defines quantization scheme for the model. Two types of presets are available: + + * ``PERFORMANCE`` (default) - defines symmetric quantization of weights and activations + + * ``MIXED`` - weights are quantized with symmetric quantization and the activations are quantized with asymmetric quantization. 
This preset is recommended for models with non-ReLU and asymmetric activation functions, e.g. ELU, PReLU, GELU, etc. + +.. code-block:: python + + nncf.quantize(model, dataset, preset=nncf.QuantizationPreset.MIXED) + +* ``fast_bias_correction`` - when set to False, enables a more accurate bias (error) correction algorithm that can be used to improve the accuracy of the model. True is used by default to minimize quantization time. + +.. code-block:: python + + nncf.quantize(model, dataset, fast_bias_correction=False) + +* ``subset_size`` - defines the number of samples from the calibration dataset that will be used to estimate quantization parameters of activations. The default value is 300. + +.. code-block:: python + + nncf.quantize(model, dataset, subset_size=1000) + +* ``ignored_scope`` - this parameter can be used to exclude some layers from the quantization process to preserve the model accuracy. For example, when you want to exclude the last layer of the model from quantization. Below are some examples of how to use this parameter: + +.. code-block:: python + + #Exclude by layer name: + names = ['layer_1', 'layer_2', 'layer_3'] + nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(names=names)) + + #Exclude by layer type: + types = ['Conv2d', 'Linear'] + nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(types=types)) + + #Exclude by regular expression: + regex = '.*layer_.*' + nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(patterns=regex)) + + #Exclude by subgraphs: + # In this case, all nodes along all simple paths in the graph + # from input to output nodes will be excluded from the quantization process. + subgraph = nncf.Subgraph(inputs=['layer_1', 'layer_2'], outputs=['layer_3']) + nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(subgraphs=[subgraph])) + + +* ``target_device`` - defines the target device, the specificity of which will be taken into account during optimization. The following values are supported: ``ANY`` (default), ``CPU``, ``CPU_SPR``, ``GPU``, and ``NPU``. + +.. code-block:: python + + nncf.quantize(model, dataset, target_device=nncf.TargetDevice.CPU) + +* ``advanced_parameters`` - used to specify advanced quantization parameters for fine-tuning the quantization algorithm. Defined by nncf.quantization.advanced_parameters NNCF submodule. None is default. + +After these steps, we finished running the quantization flow, and we will get the quantized model. + + +4. Lower into OpenVINO representation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +After that the FX Graph can utilize OpenVINO optimizations using `torch.compile(…, backend=”openvino”) `_ functionality. + +.. code-block:: python + + with torch.no_grad(): + optimized_model = torch.compile(quantized_model, backend="openvino") + + # Running some benchmark + optimized_model(*example_inputs) + + +The optimized model is using low-level kernels designed specifically for Intel CPU. +This should significantly speed up inference time in comparison with the eager model. + +Conclusion +------------ + +With this tutorial, we introduce how to use torch.compile with the OpenVINO backend with models quantized via ``nncf.quantize``. +For further information, please visit `complete example on renset18 model `_. 
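The ``nncf.quantize`` flow that this first revision documents can be condensed into one runnable sketch. The snippet below is illustrative and not part of the patch: the calibration loader is a placeholder built from the dummy input, a real workflow would pass a representative ``(image, label)`` dataset, and only APIs shown in the tutorial above are used (``torch.export``, ``nncf.Dataset``, ``nncf.quantize``, and the ``torch.compile`` OpenVINO backend).

.. code-block:: python

    import torch
    import torchvision.models as models
    import openvino.torch  # registers the "openvino" backend for torch.compile

    import nncf
    from nncf.torch import disable_patching

    # Eager FP32 model and a dummy input batch
    model = models.resnet18(pretrained=True).eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Placeholder calibration data: a single random (image, label) pair;
    # substitute a representative validation loader in practice.
    calibration_loader = torch.utils.data.DataLoader([(example_inputs[0][0], 0)])

    def transform_fn(data_item):
        images, _ = data_item  # keep the inputs, drop the labels
        return images

    calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)

    with torch.no_grad(), disable_patching():
        # Capture the ATen graph, quantize it, then lower it to OpenVINO
        exported_model = torch.export.export(model, example_inputs).module()
        quantized_model = nncf.quantize(exported_model, calibration_dataset)
        optimized_model = torch.compile(quantized_model, backend="openvino")
        optimized_model(*example_inputs)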
From acf1647329e72de081a4035f42b74fadc311cd44 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 28 Jan 2025 16:14:10 +0100 Subject: [PATCH 02/11] OpenVINOQuantizer --- ...uantization.rst => openvino_quantizer.rst} | 134 ++++++++---------- 1 file changed, 62 insertions(+), 72 deletions(-) rename prototype_source/{openvino_nncf_quantization.rst => openvino_quantizer.rst} (55%) diff --git a/prototype_source/openvino_nncf_quantization.rst b/prototype_source/openvino_quantizer.rst similarity index 55% rename from prototype_source/openvino_nncf_quantization.rst rename to prototype_source/openvino_quantizer.rst index 94c6968eaf5..3041223b013 100644 --- a/prototype_source/openvino_nncf_quantization.rst +++ b/prototype_source/openvino_quantizer.rst @@ -1,4 +1,4 @@ -PyTorch 2 Export Quantization with NNCF quantization and OpenVINO runtime +PyTorch 2 Export Quantization with OpenVINO backend =========================================================================== **Author**: dlyakhov, asuslov, aamir, # TODO: add required authors @@ -9,13 +9,13 @@ Introduction This tutorial introduces the steps for utilizing the `Neural Network Compression Framework (nncf) `_ to generate a quantized model customized for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. -The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and perform quantization transformations on top of the ATen graph. +The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and performs quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. OpenVINO is the new backend that compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. -The quantization flow mainly includes three steps: +The quantization flow mainly includes four steps: -- Step 1: OpenVINO and NNCF installation. +- Step 1: Install OpenVINO and NNCF. - Step 2: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. - Step 3: Apply the Quantization flow based on the captured FX Graph. - Step 4: Lower the quantized model into OpenVINO representation with the API ``torch.compile``. @@ -33,9 +33,14 @@ The high-level architecture of this flow could look like this: | FX Graph in ATen | - | + | OpenVINOQuantizer + | / —-------------------------------------------------------- - | nncf.quantize | + | prepare_pt2e | + | | | + | Calibrate + | | | + | convert_pt2e | —-------------------------------------------------------- | Quantized Model @@ -69,10 +74,13 @@ We will start by performing the necessary imports, capturing the FX Graph from t .. 
code-block:: python - import torch - import torchvision.models as models import copy import openvino.torch + import torch + import torchvision.models as models + from torch.ao.quantization.quantize_pt2e import convert_pt2e + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch.ao.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer import nncf from nncf.torch import disable_patching @@ -90,109 +98,90 @@ We will start by performing the necessary imports, capturing the FX Graph from t example_inputs = (x,) # Capture the FX Graph to be quantized - with torch.no_grad(): - with disable_patching(): - exported_model = torch.export.export(model, example_inputs).module() + with torch.no_grad(), disable_patching(): + exported_model = torch.export.export(model, example_inputs).module() -Next, we will have the FX Module to be quantized. 3. Apply Quantization ^^^^^^^^^^^^^^^^^^^^^^^ -Before the quantization, we need to create an instance of the nncf.Dataset class that represents the calibration dataset. -The ``nncf.Dataset`` class can be a wrapper over the framework dataset object that is used for model training or validation -The class constructor receives the dataset object and an optional transformation function. +After we capture the FX Module to be quantized, we will import the OpenVINOQuantizer. -The transformation function is a function that takes a sample from the dataset and returns data that can be passed to the model for inference. -For example, this function can take a tuple of a data tensor and labels tensor and return the former while ignoring the latter. -The transformation function is used to avoid modifying the dataset code to make it compatible with the quantization API. -The function is applied to each sample from the dataset before passing it to the model for inference. -The following code snippet shows how to create an instance of the ``nncf.Dataset`` class: .. code-block:: python - calibration_loader = torch.utils.data.DataLoader([example_inputs]) + quantizer = OpenVINOQuantizer() - def transform_fn(data_item): - # In the transformation function, - # user can separate labels and input data - # from the given data item: - # images, _ = data_item - return data_item +``OpenVINOQuantizer`` has several optional parameters that allow tuning the quantization process to get a more accurate model. +Below is the list of essential parameters and their description: - calibration_dataset = nncf.Dataset(calibration_loader, transform_fn) -If there is no framework dataset object, you can create your own entity that implements the Iterable interface in Python, -for example, the list of images, and returns data samples feasible for inference. In this case, a transformation function is not required. +* ``preset`` - defines quantization scheme for the model. Two types of presets are available: -Once the dataset is ready and the model object is instantiated, you can apply 8-bit quantization to it. + * ``PERFORMANCE`` (default) - defines symmetric quantization of weights and activations -.. code-block:: python + * ``MIXED`` - weights are quantized with symmetric quantization and the activations are quantized with asymmetric quantization. This preset is recommended for models with non-ReLU and asymmetric activation functions, e.g. ELU, PReLU, GELU, etc. - with disable_patching(): - quantized_model = nncf.quantize(exported_model, calibration_dataset) + .. 
code-block:: python -``nncf.quantize()`` function has several optional parameters that allow tuning the quantization process to get a more accurate model. -Below is the list of parameters and their description: + OpenVINOQuantizer(preset=nncf.QuantizationPreset.MIXED) * ``model_type`` - used to specify quantization scheme required for specific type of the model. Transformer is the only supported special quantization scheme to preserve accuracy after quantization of Transformer models (BERT, DistilBERT, etc.). None is default, i.e. no specific scheme is defined. -.. code-block:: python - nncf.quantize(model, dataset, model_type=nncf.ModelType.Transformer) + .. code-block:: python -* ``preset`` - defines quantization scheme for the model. Two types of presets are available: + OpenVINOQuantizer(model_type=nncf.ModelType.Transformer) - * ``PERFORMANCE`` (default) - defines symmetric quantization of weights and activations +* ``ignored_scope`` - this parameter can be used to exclude some layers from the quantization process to preserve the model accuracy. For example, when you want to exclude the last layer of the model from quantization. Below are some examples of how to use this parameter: - * ``MIXED`` - weights are quantized with symmetric quantization and the activations are quantized with asymmetric quantization. This preset is recommended for models with non-ReLU and asymmetric activation functions, e.g. ELU, PReLU, GELU, etc. + .. code-block:: python -.. code-block:: python + #Exclude by layer name: + names = ['layer_1', 'layer_2', 'layer_3'] + OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(names=names)) - nncf.quantize(model, dataset, preset=nncf.QuantizationPreset.MIXED) + #Exclude by layer type: + types = ['Conv2d', 'Linear'] + OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(types=types)) -* ``fast_bias_correction`` - when set to False, enables a more accurate bias (error) correction algorithm that can be used to improve the accuracy of the model. True is used by default to minimize quantization time. + #Exclude by regular expression: + regex = '.*layer_.*' + OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(patterns=regex)) -.. code-block:: python + #Exclude by subgraphs: + # In this case, all nodes along all simple paths in the graph + # from input to output nodes will be excluded from the quantization process. + subgraph = nncf.Subgraph(inputs=['layer_1', 'layer_2'], outputs=['layer_3']) + OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(subgraphs=[subgraph])) - nncf.quantize(model, dataset, fast_bias_correction=False) -* ``subset_size`` - defines the number of samples from the calibration dataset that will be used to estimate quantization parameters of activations. The default value is 300. +* ``target_device`` - defines the target device, the specificity of which will be taken into account during optimization. The following values are supported: ``ANY`` (default), ``CPU``, ``CPU_SPR``, ``GPU``, and ``NPU``. -.. code-block:: python + .. code-block:: python - nncf.quantize(model, dataset, subset_size=1000) + OpenVINOQuantizer(target_device=nncf.TargetDevice.CPU) -* ``ignored_scope`` - this parameter can be used to exclude some layers from the quantization process to preserve the model accuracy. For example, when you want to exclude the last layer of the model from quantization. Below are some examples of how to use this parameter: -.. code-block:: python +After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. 
+``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. - #Exclude by layer name: - names = ['layer_1', 'layer_2', 'layer_3'] - nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(names=names)) +.. code-block:: python - #Exclude by layer type: - types = ['Conv2d', 'Linear'] - nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(types=types)) + prepared_model = prepare_pt2e(exported_model, quantizer) - #Exclude by regular expression: - regex = '.*layer_.*' - nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(patterns=regex)) +Now, we will calibrate the ``prepared_model`` after the observers are inserted in the model. - #Exclude by subgraphs: - # In this case, all nodes along all simple paths in the graph - # from input to output nodes will be excluded from the quantization process. - subgraph = nncf.Subgraph(inputs=['layer_1', 'layer_2'], outputs=['layer_3']) - nncf.quantize(model, dataset, ignored_scope=nncf.IgnoredScope(subgraphs=[subgraph])) +.. code-block:: python + # We use the dummy data as an example here + prepared_model(*example_inputs) -* ``target_device`` - defines the target device, the specificity of which will be taken into account during optimization. The following values are supported: ``ANY`` (default), ``CPU``, ``CPU_SPR``, ``GPU``, and ``NPU``. +Finally, we will convert the calibrated Model to a quantized Model. ``convert_pt2e`` takes a calibrated model and produces a quantized model. .. code-block:: python - nncf.quantize(model, dataset, target_device=nncf.TargetDevice.CPU) - -* ``advanced_parameters`` - used to specify advanced quantization parameters for fine-tuning the quantization algorithm. Defined by nncf.quantization.advanced_parameters NNCF submodule. None is default. + quantized_model = convert_pt2e(prepared_model) After these steps, we finished running the quantization flow, and we will get the quantized model. @@ -204,18 +193,19 @@ After that the FX Graph can utilize OpenVINO optimizations using `torch.compile( .. code-block:: python - with torch.no_grad(): + with torch.no_grad(), disable_patching(): optimized_model = torch.compile(quantized_model, backend="openvino") # Running some benchmark optimized_model(*example_inputs) + The optimized model is using low-level kernels designed specifically for Intel CPU. This should significantly speed up inference time in comparison with the eager model. Conclusion ------------ -With this tutorial, we introduce how to use torch.compile with the OpenVINO backend with models quantized via ``nncf.quantize``. -For further information, please visit `complete example on renset18 model `_. +With this tutorial, we introduce how to use torch.compile with the OpenVINO backend and the OpenVINO quantizer. +For further information, please visit `OpenVINO deploymet via torch.compile documentation `_. 
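The prepare, calibrate, and convert sequence that this revision introduces is usually run over several batches rather than one. The sketch below is illustrative, not part of the patch: ``exported_model`` comes from the capture step above, ``calibration_batches`` is a placeholder iterable of representative input tensors, and the ``OpenVINOQuantizer`` import uses the NNCF path that a later revision in this series settles on.

.. code-block:: python

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    # Import path from a later revision of this series; this patch instead
    # imports it from torch.ao.quantization.quantizer.openvino_quantizer.
    from nncf.experimental.torch.fx import OpenVINOQuantizer

    quantizer = OpenVINOQuantizer()
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Feed a few representative batches through the inserted observers.
    with torch.no_grad():
        for batch in calibration_batches:
            prepared_model(batch)

    # A later patch in this series additionally passes fold_quantize=False here.
    quantized_model = convert_pt2e(prepared_model)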
From 5b1c99aaaddb6b25e13a39cef6ecb99265ad693c Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 7 Feb 2025 13:14:10 +0100 Subject: [PATCH 03/11] Apply suggestions from code review Co-authored-by: Alexander Suslov Co-authored-by: Yamini Nimmagadda --- prototype_source/openvino_quantizer.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 3041223b013..8bd0cba6245 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -1,4 +1,4 @@ -PyTorch 2 Export Quantization with OpenVINO backend +PyTorch 2 Export Quantization for OpenVINO torch.compile backend. =========================================================================== **Author**: dlyakhov, asuslov, aamir, # TODO: add required authors @@ -6,19 +6,18 @@ PyTorch 2 Export Quantization with OpenVINO backend Introduction -------------- -This tutorial introduces the steps for utilizing the `Neural Network Compression Framework (nncf) `_ to generate a quantized model customized -for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. +This tutorial demonstrates how to use `OpenVINOQuantizer` from `Neural Network Compression Framework (NNCF) `_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and performs quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. -OpenVINO is the new backend that compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. +OpenVINO backend compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. The quantization flow mainly includes four steps: - Step 1: Install OpenVINO and NNCF. - Step 2: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. -- Step 3: Apply the Quantization flow based on the captured FX Graph. -- Step 4: Lower the quantized model into OpenVINO representation with the API ``torch.compile``. +- Step 3: Apply the PyTorch 2 Export Quantization flow with OpenVINOQuantizer based on the captured FX Graph. +- Step 4: Lower the quantized model into OpenVINO representation with the API `torch.compile `_. The high-level architecture of this flow could look like this: @@ -80,7 +79,6 @@ We will start by performing the necessary imports, capturing the FX Graph from t import torchvision.models as models from torch.ao.quantization.quantize_pt2e import convert_pt2e from torch.ao.quantization.quantize_pt2e import prepare_pt2e - from torch.ao.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer import nncf from nncf.torch import disable_patching @@ -111,6 +109,8 @@ After we capture the FX Module to be quantized, we will import the OpenVINOQuant .. code-block:: python + from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer + quantizer = OpenVINOQuantizer() ``OpenVINOQuantizer`` has several optional parameters that allow tuning the quantization process to get a more accurate model. 
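The optional parameters listed in this file can also be combined in a single constructor call. The following is a hypothetical configuration, not part of the patch: the ignored layer name is a placeholder, whether these knobs are appropriate together depends on the model, and the short ``nncf.experimental.torch.fx`` import path comes from a later revision in this series.

.. code-block:: python

    import nncf
    from nncf.experimental.torch.fx import OpenVINOQuantizer

    # Hypothetical combination of the documented parameters
    quantizer = OpenVINOQuantizer(
        preset=nncf.QuantizationPreset.MIXED,                # asymmetric activations
        target_device=nncf.TargetDevice.CPU,                 # optimize for CPU
        ignored_scope=nncf.IgnoredScope(names=["layer_1"]),  # placeholder name
    )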
@@ -208,4 +208,4 @@ Conclusion ------------ With this tutorial, we introduce how to use torch.compile with the OpenVINO backend and the OpenVINO quantizer. -For further information, please visit `OpenVINO deploymet via torch.compile documentation `_. +For further information, please visit `OpenVINO deployment via torch.compile documentation `_. From b2eaa82c03d6d2dea960afdff6a27893a6ba2bb7 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 7 Feb 2025 13:37:55 +0100 Subject: [PATCH 04/11] Comments --- prototype_source/openvino_quantizer.rst | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 8bd0cba6245..7ac1f7e53b6 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -3,6 +3,11 @@ PyTorch 2 Export Quantization for OpenVINO torch.compile backend. **Author**: dlyakhov, asuslov, aamir, # TODO: add required authors +Prerequisites +-------------- +- [PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html) +- [How to Write a Quantizer for PyTorch 2 Export Quantization](https://pytorch.org/tutorials/prototype/pt2e_quantizer.html) + Introduction -------------- @@ -80,8 +85,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t from torch.ao.quantization.quantize_pt2e import convert_pt2e from torch.ao.quantization.quantize_pt2e import prepare_pt2e - import nncf - from nncf.torch import disable_patching + import nncf.torch # Create the Eager Model model_name = "resnet18" @@ -92,11 +96,11 @@ We will start by performing the necessary imports, capturing the FX Graph from t # Create the data, using the dummy data here as an example traced_bs = 50 - x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last) + x = torch.randn(traced_bs, 3, 224, 224) example_inputs = (x,) # Capture the FX Graph to be quantized - with torch.no_grad(), disable_patching(): + with torch.no_grad(), nncf.torch.disable_patching(): exported_model = torch.export.export(model, example_inputs).module() @@ -193,7 +197,7 @@ After that the FX Graph can utilize OpenVINO optimizations using `torch.compile( .. code-block:: python - with torch.no_grad(), disable_patching(): + with torch.no_grad(), nncf.torch.disable_patching(): optimized_model = torch.compile(quantized_model, backend="openvino") # Running some benchmark @@ -207,5 +211,6 @@ This should significantly speed up inference time in comparison with the eager m Conclusion ------------ -With this tutorial, we introduce how to use torch.compile with the OpenVINO backend and the OpenVINO quantizer. -For further information, please visit `OpenVINO deployment via torch.compile documentation `_. +This tutorial introduces how to use torch.compile with the OpenVINO backend and the OpenVINO quantizer. +For more details on NNCF and the NNCF Quantization Flow for PyTorch models, refer to the `NNCF Quantization Guide `_. +For additional information, check out the `OpenVINO Deployment via torch.compile Documentation `_. 
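The "running some benchmark" comment in the lowering step can be made concrete. Below is an illustrative timing helper that is not part of the patch: it reuses ``model``, ``optimized_model``, and ``example_inputs`` from the tutorial, and absolute numbers will vary by machine.

.. code-block:: python

    import time

    import torch
    import nncf.torch

    def measure_latency(fn, args, warmup=5, iters=20):
        # Warm-up iterations trigger torch.compile compilation and kernel caching.
        for _ in range(warmup):
            fn(*args)
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        return (time.perf_counter() - start) / iters

    with torch.no_grad(), nncf.torch.disable_patching():
        eager_ms = measure_latency(model, example_inputs) * 1e3
        int8_ms = measure_latency(optimized_model, example_inputs) * 1e3

    print(f"eager FP32: {eager_ms:.2f} ms/iter, quantized OpenVINO: {int8_ms:.2f} ms/iter")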
From 810899a39ca3ebbee7b5d2df1701b29d8ba707e6 Mon Sep 17 00:00:00 2001 From: daniil-lyakhov Date: Thu, 20 Feb 2025 18:35:29 +0100 Subject: [PATCH 05/11] NNCF API docs --- prototype_source/openvino_quantizer.rst | 37 ++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 7ac1f7e53b6..90783312453 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -1,12 +1,12 @@ PyTorch 2 Export Quantization for OpenVINO torch.compile backend. =========================================================================== -**Author**: dlyakhov, asuslov, aamir, # TODO: add required authors +**Authors**: `Daniil Lyakhov `_, `Alexander Suslov `_, `Aamir Nazir `_ Prerequisites -------------- -- [PyTorch 2 Export Post Training Quantization](https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html) -- [How to Write a Quantizer for PyTorch 2 Export Quantization](https://pytorch.org/tutorials/prototype/pt2e_quantizer.html) +- `PyTorch 2 Export Post Training Quantization `_ +- `How to Write a Quantizer for PyTorch 2 Export Quantization `_ Introduction -------------- @@ -113,7 +113,7 @@ After we capture the FX Module to be quantized, we will import the OpenVINOQuant .. code-block:: python - from nncf.experimental.torch.fx.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer + from nncf.experimental.torch.fx import OpenVINOQuantizer quantizer = OpenVINOQuantizer() @@ -166,6 +166,7 @@ Below is the list of essential parameters and their description: OpenVINOQuantizer(target_device=nncf.TargetDevice.CPU) +For futher details on `OpenVINOQuantizer` please see the `documentation `_. After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. ``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. @@ -208,6 +209,34 @@ After that the FX Graph can utilize OpenVINO optimizations using `torch.compile( The optimized model is using low-level kernels designed specifically for Intel CPU. This should significantly speed up inference time in comparison with the eager model. +5. Optional: Improve quantized model metrics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +NNCF implements advanced quantization algorithms like SmoothQuant and BiasCorrection, which help +improve the quantized model metrics while minimizing the output discrepancies between the original and compressed models. +These advanced NNCF algorithms can be accessed via the NNCF `quantize_pt2e` API: + +.. code-block:: python + + from nncf.experimental.torch.fx import quantize_pt2e + + calibration_loader = torch.utils.data.DataLoader(...) + + + def transform_fn(data_item): + images, _ = data_item + return images + + + calibration_dataset = nncf.Dataset(calibration_loader, transform_fn) + quantized_model = quantize_pt2e( + exported_model, quantizer, calibration_dataset, smooth_quant=True, fast_bias_correction=False + ) + + +For further details, please see the `documentation `_ +and a complete `example on Resnet18 quantization `_. 
+ Conclusion ------------ From 82a47a5672d3150dc2fad7241bbd283ca2c1c456 Mon Sep 17 00:00:00 2001 From: daniil-lyakhov Date: Mon, 24 Feb 2025 15:42:59 +0100 Subject: [PATCH 06/11] Comments --- prototype_source/openvino_quantizer.rst | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 90783312453..e7857c5acf3 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -1,7 +1,7 @@ PyTorch 2 Export Quantization for OpenVINO torch.compile backend. =========================================================================== -**Authors**: `Daniil Lyakhov `_, `Alexander Suslov `_, `Aamir Nazir `_ +**Authors**: `Daniil Lyakhov `_, `Aamir Nazir `_, `Alexander Suslov `_, `Yamini Nimmagadda `_, `Alexander Kozlov `_ Prerequisites -------------- @@ -11,18 +11,21 @@ Prerequisites Introduction -------------- +**This is an experimental feature, the quantization API is subject to change.** + This tutorial demonstrates how to use `OpenVINOQuantizer` from `Neural Network Compression Framework (NNCF) `_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. +`OpenVINOQuantizer` unlocks the full potential of low-precision OpenVINO kernels due to the placement of quantizers designed specifically for the OpenVINO. -The pytorch 2 export quantization flow uses the torch.export to capture the model into a graph and performs quantization transformations on top of the ATen graph. +The PyTorch 2 export quantization flow uses the torch.export to capture the model into a graph and performs quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. OpenVINO backend compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. The quantization flow mainly includes four steps: -- Step 1: Install OpenVINO and NNCF. -- Step 2: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. -- Step 3: Apply the PyTorch 2 Export Quantization flow with OpenVINOQuantizer based on the captured FX Graph. -- Step 4: Lower the quantized model into OpenVINO representation with the API `torch.compile `_. +- Step 1: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. +- Step 2: Apply the PyTorch 2 Export Quantization flow with OpenVINOQuantizer based on the captured FX Graph. +- Step 3: Lower the quantized model into OpenVINO representation with the API `torch.compile `_. +- Optional step 4: : Improve quantized model metrics via `quantize_pt2 `_ method. The high-level architecture of this flow could look like this: @@ -61,7 +64,7 @@ Post Training Quantization Now, we will walk you through a step-by-step tutorial for how to use it with `torchvision resnet18 model `_ for post training quantization. -1. OpenVINO and NNCF installation +Prerequisite: OpenVINO and NNCF installation ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ OpenVINO and NNCF could be easily installed via `pip distribution `_: @@ -71,7 +74,7 @@ OpenVINO and NNCF could be easily installed via `pip distribution `_ functionality. 
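Before benchmarking the lowered model, a quick sanity check is to compare the compiled INT8 outputs against the eager FP32 model. This illustrative snippet is not part of the patch: it reuses ``model``, ``optimized_model``, and ``example_inputs`` from the tutorial, and small numerical deviations are expected because quantization is lossy.

.. code-block:: python

    import torch
    import nncf.torch

    with torch.no_grad(), nncf.torch.disable_patching():
        reference = model(*example_inputs)         # FP32 eager output
        result = optimized_model(*example_inputs)  # INT8 OpenVINO output

    # Report the worst-case elementwise deviation between the two outputs.
    max_abs_diff = (reference - result).abs().max().item()
    print(f"max abs difference vs FP32: {max_abs_diff:.4f}")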
@@ -209,7 +212,7 @@ After that the FX Graph can utilize OpenVINO optimizations using `torch.compile( The optimized model is using low-level kernels designed specifically for Intel CPU. This should significantly speed up inference time in comparison with the eager model. -5. Optional: Improve quantized model metrics +4. Optional: Improve quantized model metrics ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ NNCF implements advanced quantization algorithms like SmoothQuant and BiasCorrection, which help From 26f044b66eecd6cd65abccb4f25bca7199e60fe9 Mon Sep 17 00:00:00 2001 From: daniil-lyakhov Date: Mon, 24 Feb 2025 15:57:32 +0100 Subject: [PATCH 07/11] fold_quantize=False --- prototype_source/openvino_quantizer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index e7857c5acf3..a4940c7caf8 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -189,7 +189,7 @@ Finally, we will convert the calibrated Model to a quantized Model. ``convert_pt .. code-block:: python - quantized_model = convert_pt2e(prepared_model) + quantized_model = convert_pt2e(prepared_model, fold_quantize=False) After these steps, we finished running the quantization flow, and we will get the quantized model. From 75d35495ccea3b5d24cfb9c870fd37b14e07138e Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 11 Apr 2025 14:42:46 +0200 Subject: [PATCH 08/11] Update prototype_source/openvino_quantizer.rst Co-authored-by: Alexander Suslov --- prototype_source/openvino_quantizer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index a4940c7caf8..2f8279699c6 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -25,7 +25,7 @@ The quantization flow mainly includes four steps: - Step 1: Capture the FX Graph from the eager Model based on the `torch export mechanism `_. - Step 2: Apply the PyTorch 2 Export Quantization flow with OpenVINOQuantizer based on the captured FX Graph. - Step 3: Lower the quantized model into OpenVINO representation with the API `torch.compile `_. -- Optional step 4: : Improve quantized model metrics via `quantize_pt2 `_ method. +- Optional step 4: : Improve quantized model metrics via `quantize_pt2e `_ method. The high-level architecture of this flow could look like this: From f09a85f4bda188d2a16a779127ffd9ea7b5a5e70 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 14 Apr 2025 16:04:52 +0200 Subject: [PATCH 09/11] Spelling / comments --- en-wordlist.txt | 11 +++++++++++ prototype_source/openvino_quantizer.rst | 14 ++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/en-wordlist.txt b/en-wordlist.txt index 6a794e7786f..baf75d75ac0 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -698,3 +698,14 @@ TorchServe Inductor’s onwards recompilations +BiasCorrection +ELU +GELU +NNCF +OpenVINO +OpenVINOQuantizer +PReLU +Quantizer +SmoothQuant +quantizer +quantizers \ No newline at end of file diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 2f8279699c6..261b37ba766 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -11,13 +11,15 @@ Prerequisites Introduction -------------- -**This is an experimental feature, the quantization API is subject to change.** +.. 
note:: + + This is an experimental feature, the quantization API is subject to change. This tutorial demonstrates how to use `OpenVINOQuantizer` from `Neural Network Compression Framework (NNCF) `_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. `OpenVINOQuantizer` unlocks the full potential of low-precision OpenVINO kernels due to the placement of quantizers designed specifically for the OpenVINO. The PyTorch 2 export quantization flow uses the torch.export to capture the model into a graph and performs quantization transformations on top of the ATen graph. -This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. +This approach is expected to have significantly higher model coverage, improved flexibility, and a simplified UX. OpenVINO backend compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. The quantization flow mainly includes four steps: @@ -134,7 +136,7 @@ Below is the list of essential parameters and their description: OpenVINOQuantizer(preset=nncf.QuantizationPreset.MIXED) -* ``model_type`` - used to specify quantization scheme required for specific type of the model. Transformer is the only supported special quantization scheme to preserve accuracy after quantization of Transformer models (BERT, DistilBERT, etc.). None is default, i.e. no specific scheme is defined. +* ``model_type`` - used to specify quantization scheme required for specific type of the model. Transformer is the only supported special quantization scheme to preserve accuracy after quantization of Transformer models (BERT, Llama, etc.). None is default, i.e. no specific scheme is defined. .. code-block:: python @@ -169,7 +171,7 @@ Below is the list of essential parameters and their description: OpenVINOQuantizer(target_device=nncf.TargetDevice.CPU) -For futher details on `OpenVINOQuantizer` please see the `documentation `_. +For further details on `OpenVINOQuantizer` please see the `documentation `_. After we import the backend-specific Quantizer, we will prepare the model for post-training quantization. ``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model. @@ -215,8 +217,8 @@ This should significantly speed up inference time in comparison with the eager m 4. Optional: Improve quantized model metrics ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -NNCF implements advanced quantization algorithms like SmoothQuant and BiasCorrection, which help -improve the quantized model metrics while minimizing the output discrepancies between the original and compressed models. +NNCF implements advanced quantization algorithms like `SmoothQuant `_ and `BiasCorrection `_, which help +to improve the quantized model metrics while minimizing the output discrepancies between the original and compressed models. These advanced NNCF algorithms can be accessed via the NNCF `quantize_pt2e` API: .. 
code-block:: python From f3137be0d646deae6368bb4d08e902ecbb4fb224 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 16 Apr 2025 11:16:33 +0200 Subject: [PATCH 10/11] prototype_index.rst is updated --- prototype_source/openvino_quantizer.rst | 2 +- prototype_source/prototype_index.rst | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 261b37ba766..4d98b2b7dd5 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -1,4 +1,4 @@ -PyTorch 2 Export Quantization for OpenVINO torch.compile backend. +PyTorch 2 Export Quantization for OpenVINO torch.compile Backend. =========================================================================== **Authors**: `Daniil Lyakhov `_, `Aamir Nazir `_, `Alexander Suslov `_, `Yamini Nimmagadda `_, `Alexander Kozlov `_ diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index 927f5f694b8..489e79fe011 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -96,6 +96,13 @@ Prototype features are not available as part of binary distributions like PyPI o :link: ../prototype/pt2e_quant_x86_inductor.html :tags: Quantization +.. customcarditem:: + :header: PyTorch 2 Export Quantization for OpenVINO torch.compile Backend + :card_description: Learn how to use PT2 Export Quantization with OpenVINO torch.compile Backend. + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../prototype/openvino_quantizer.html + :tags: Quantization + .. Sparsity .. customcarditem:: From b7d2781a4ba4f370dec29ee4a1713dba69bea547 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Wed, 16 Apr 2025 11:10:29 +0200 Subject: [PATCH 11/11] Apply suggestions from code review Co-authored-by: Svetlana Karslioglu --- prototype_source/openvino_quantizer.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/prototype_source/openvino_quantizer.rst b/prototype_source/openvino_quantizer.rst index 4d98b2b7dd5..4d9a73edd64 100644 --- a/prototype_source/openvino_quantizer.rst +++ b/prototype_source/openvino_quantizer.rst @@ -15,10 +15,10 @@ Introduction This is an experimental feature, the quantization API is subject to change. -This tutorial demonstrates how to use `OpenVINOQuantizer` from `Neural Network Compression Framework (NNCF) `_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. -`OpenVINOQuantizer` unlocks the full potential of low-precision OpenVINO kernels due to the placement of quantizers designed specifically for the OpenVINO. +This tutorial demonstrates how to use ``OpenVINOQuantizer`` from `Neural Network Compression Framework (NNCF) `_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend `_ and explains how to lower the quantized model into the `OpenVINO `_ representation. +``OpenVINOQuantizer`` unlocks the full potential of low-precision OpenVINO kernels due to the placement of quantizers designed specifically for the OpenVINO. -The PyTorch 2 export quantization flow uses the torch.export to capture the model into a graph and performs quantization transformations on top of the ATen graph. 
+The PyTorch 2 export quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph.
 This approach is expected to have significantly higher model coverage, improved flexibility, and a simplified UX.
 OpenVINO backend compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model.
 
 The quantization flow mainly includes four steps:
 
 - Step 1: Capture the FX Graph from the eager model based on the `torch export mechanism `_.
 - Step 2: Apply the PyTorch 2 Export Quantization flow with OpenVINOQuantizer based on the captured FX Graph.
 - Step 3: Lower the quantized model into OpenVINO representation with the `torch.compile `_ API.
 - Optional step 4: Improve quantized model metrics via the `quantize_pt2e `_ method.
 
 The high-level architecture of this flow could look like this:
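In code, the end-to-end flow that this architecture describes can be sketched as follows. This is an illustrative consolidation of the final-state APIs from the patches above (``OpenVINOQuantizer`` and ``quantize_pt2e`` from ``nncf.experimental.torch.fx``, lowering via ``torch.compile``), and the calibration loader is a placeholder.

.. code-block:: python

    import torch
    import torchvision.models as models
    import openvino.torch  # registers the "openvino" backend for torch.compile

    import nncf
    import nncf.torch
    from nncf.experimental.torch.fx import OpenVINOQuantizer, quantize_pt2e

    model = models.resnet18(pretrained=True).eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # Placeholder calibration data; use a representative dataset in practice.
    calibration_loader = torch.utils.data.DataLoader([(example_inputs[0][0], 0)])
    calibration_dataset = nncf.Dataset(calibration_loader, lambda item: item[0])

    with torch.no_grad(), nncf.torch.disable_patching():
        exported_model = torch.export.export(model, example_inputs).module()
        quantized_model = quantize_pt2e(
            exported_model,
            OpenVINOQuantizer(),
            calibration_dataset,
            smooth_quant=True,           # advanced NNCF algorithm from the optional step
            fast_bias_correction=False,  # slower, more accurate bias correction
        )
        optimized_model = torch.compile(quantized_model, backend="openvino")
        optimized_model(*example_inputs)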