diff --git a/README.md b/README.md index 4347feec8cf..65bbb7f3e5c 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,9 @@ def transform_fn(data_item): calibration_dataset = nncf.Dataset(val_dataset, transform_fn) # Step 3: Run the quantization pipeline quantized_model = nncf.quantize(model, calibration_dataset) +# Step 4: Remove auxiliary layers and operations added during the quantization process, +# resulting in a clean, fully quantized model ready for deployment. +stripped_model = nncf.strip(quantized_model) ``` diff --git a/docs/usage/training_time_compression/quantization_aware_training/Usage.md b/docs/usage/training_time_compression/quantization_aware_training/Usage.md index 3a5fbffb096..8ae7163e677 100644 --- a/docs/usage/training_time_compression/quantization_aware_training/Usage.md +++ b/docs/usage/training_time_compression/quantization_aware_training/Usage.md @@ -1,7 +1,7 @@ -# Use NNCF for Quantization Aware Training in PyTorch +# Use NNCF for Quantization Aware Training -This is a step-by-step tutorial on how to integrate the NNCF package into the existing PyTorch project (please see the [TensorFlow quantization documentation](../other_algorithms/LegacyQuantization.md) for integration tutorial for the existing TensorFlow project). -The use case implies that the user already has a training pipeline that reproduces training of the model in the floating point precision and pretrained model. +This is a step-by-step tutorial on how to integrate the NNCF package into the existing PyTorch or TensorFlow projects. +The use case implies that the user already has a training pipeline that reproduces training of the model in the floating point precision and pretrained model. The task is to prepare this model for accelerated inference by simulating the compression at train time. Please refer to this [document](/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md) for details of the implementation. @@ -11,11 +11,24 @@ Please refer to this [document](/docs/usage/training_time_compression/other_algo Quantize the model using the [Post Training Quantization](../../post_training_compression/post_training_quantization/Usage.md) method. +
<details><summary>PyTorch</summary>
+
 ```python
 model = TorchModel() # instance of torch.nn.Module
 quantized_model = nncf.quantize(model, ...)
 ```

+</details>
+
+<details><summary>TensorFlow</summary>
+
+```python
+model = TensorFlowModel() # instance of tf.keras.Model
+quantized_model = nncf.quantize(model, ...)
+```
+
+</details>
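The `...` in the calls above stands for the calibration dataset argument. A minimal sketch of assembling one for the PyTorch case is shown below; the toy model and loader are placeholders, and the `transform_fn`/`nncf.Dataset` pattern is the same one used in the README snippet:

```python
import nncf
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for illustration: any torch.nn.Module and any iterable of (input, label) pairs work here.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)
val_loader = DataLoader(TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 10, (16,))), batch_size=4)

def transform_fn(data_item):
    # The loader yields (images, labels) pairs; only the model input is needed for calibration.
    images, _ = data_item
    return images

calibration_dataset = nncf.Dataset(val_loader, transform_fn)
quantized_model = nncf.quantize(model, calibration_dataset)
```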
+
 ### Step 2: Run the training pipeline

 At this point, NNCF is fully integrated into your training pipeline.

@@ -27,18 +40,39 @@ Important points you should consider when training your networks with compressio

 ### Step 3: Export the compressed model

-After the compressed model has been fine-tuned to acceptable accuracy and compression stages, you can export it. There are two ways to export a model:
+After the compressed model has been fine-tuned to acceptable accuracy and compression stages, you can export it.
+
+<details><summary>PyTorch</summary>
+
+Trace the model via inference in framework operations.

-1. Trace the model via inference in framework operations.
+```python
+# To OpenVINO format
+import openvino as ov
+ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
+```
+
+</details>
+
+<details><summary>TensorFlow</summary>
+
+```python
+# To OpenVINO format
+import openvino as ov
+
+# Removes auxiliary layers and operations added during the quantization process,
+# resulting in a clean, fully quantized model ready for deployment.
+stripped_model = nncf.strip(quantized_model)
+
+ov_quantized_model = ov.convert_model(stripped_model)
+```

-    ```python
-    # To OpenVINO format
-    import openvino as ov
-    ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
-    ```
+</details>
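For completeness, a minimal end-to-end export sketch for the PyTorch path, assuming `quantized_model` from Step 1; the shape of `dummy_input` is an assumption for illustration, and `ov.save_model` writes the OpenVINO IR the same way the example scripts touched by this diff do:

```python
import openvino as ov
import torch

# Example input used to trace the quantized model; adjust the shape to your model.
dummy_input = torch.randn(1, 3, 224, 224)

ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
# Persist the IR (.xml + .bin) for deployment with the OpenVINO runtime.
ov.save_model(ov_quantized_model, "quantized_model.xml")
```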
## Saving and loading compressed models +
<details><summary>PyTorch</summary>
+
 The complete information about compression is defined by a compressed model and an NNCF config.
 The model characterizes the weights and topology of the network. The NNCF config describes how to restore the additional modules introduced by NNCF.
 The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
@@ -46,8 +80,6 @@ The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving
 Saving the quantized model allows loading the quantized modules into the target model in a new Python process and requires
 only an example input for the target module, the corresponding NNCF config and the quantized model state dict.

-### Saving and loading compressed models in PyTorch
-
 ```python
 # save part
 quantized_model = nncf.quantize(model, calibration_dataset)
@@ -70,10 +102,53 @@ quantized_model.load_state_dict(state_dict)
 You can save the `compressed_model` object with `torch.save` as usual: via its `state_dict` and `load_state_dict` methods.
+</details>
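A compact sketch of the save side of this flow, reusing the `model`/`calibration_dataset` names from Step 1; the checkpoint dictionary layout is an illustration, while `state_dict()` and `quantized_model.nncf.get_config()` come from the text above:

```python
import torch
import nncf

quantized_model = nncf.quantize(model, calibration_dataset)
# ... fine-tune the quantized model as usual ...

# Persist both pieces needed to rebuild the model in a new process:
# the weights and the NNCF config describing the inserted quantizer modules.
torch.save(
    {
        "state_dict": quantized_model.state_dict(),
        "nncf_config": quantized_model.nncf.get_config(),
    },
    "qat_checkpoint.pth",
)
```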
+
+<details><summary>TensorFlow</summary>
+
+To save a model checkpoint, use the following API:
+
+```python
+from nncf.tensorflow import ConfigState
+from nncf.tensorflow import get_config
+from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback
+
+nncf_config = get_config(quantized_model)
+checkpoint = tf.train.Checkpoint(model=quantized_model,
+                                 nncf_config_state=ConfigState(nncf_config),
+                                 ...  # the rest of the user-defined objects to save
+                                 )
+callbacks = []
+callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
+...
+quantized_model.fit(..., callbacks=callbacks)
+```
+
+To restore the model from the checkpoint, use the following API:
+
+```python
+from nncf.tensorflow import ConfigState
+from nncf.tensorflow import load_from_config
+
+checkpoint = tf.train.Checkpoint(nncf_config_state=ConfigState())
+checkpoint.restore(path_to_checkpoint)
+
+quantized_model = load_from_config(model, checkpoint.nncf_config_state.config)
+
+checkpoint = tf.train.Checkpoint(model=quantized_model,
+                                 ...  # the rest of the user-defined objects to load
+                                 )
+checkpoint.restore(path_to_checkpoint)
+```
+
+</details>
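A note on the restore flow above: the checkpoint is read twice because the quantizer layers must first be rebuilt from the stored config via `load_from_config` before their weights can be restored into the rebuilt model. After both passes, training simply continues; a hedged sketch in which the optimizer, loss and `train_dataset` are placeholders for the user's own setup:

```python
import tensorflow as tf

# Resume fine-tuning of the restored quantized model with the same checkpoint callbacks.
quantized_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
quantized_model.fit(train_dataset, epochs=1, callbacks=callbacks)
```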
+
 ## Advanced usage

 ### Compression of custom modules

+<details><summary>PyTorch</summary>
+
 With no target model code modifications, NNCF only supports native PyTorch modules with respect to trainable parameter (weight) compression, such as `torch.nn.Conv2d`.
 If your model contains a custom, non-PyTorch standard module with trainable weights that should be compressed, you can register it using the `@nncf.register_module` decorator:

@@ -91,4 +166,9 @@ If registered module should be ignored by specific algorithms use `ignored_algor
 In the example above, the NNCF-compressed models that contain instances of `MyModule` will have the corresponding modules extended with functionality that will allow NNCF to quantize the `weight` parameter of `MyModule` before it takes part in `MyModule`'s `forward` calculation.

-See a PyTorch [example](/examples/quantization_aware_training/torch/resnet18/README.md) for **Quantization** Compression scenario on Tiny ImageNet-200 dataset.
+</details>
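The decorator example itself sits outside the hunk context above, so a minimal hypothetical sketch of the registration pattern is given below; the module, its constructor arguments and the weight shape are all placeholders:

```python
import torch
import nncf

@nncf.register_module()
class MyModule(torch.nn.Module):
    # A custom layer with a trainable `weight` parameter that NNCF would not compress by default.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(x, self.weight)
```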
+ +## Examples + +- See a PyTorch [example](/examples/quantization_aware_training/torch/resnet18/README.md) for **Quantization** Compression scenario on Tiny ImageNet-200 dataset. +- See a TensorFlow [example](/examples/quantization_aware_training/tensorflow/mobilenet_v2/README.md) for **Quantization** Compression scenario on imagenette/320px-v2 dataset. diff --git a/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py b/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py index 8e175f7dd3f..5f22d516e22 100644 --- a/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py +++ b/examples/post_training_quantization/tensorflow/mobilenet_v2/main.py @@ -151,8 +151,8 @@ def transform_fn(data_item): ############################################################################### # Benchmark performance, calculate compression rate and validate accuracy -ov_model = ov.convert_model(tf_model, share_weights=False) -ov_quantized_model = ov.convert_model(tf_quantized_model, share_weights=False) +ov_model = ov.convert_model(tf_model) +ov_quantized_model = ov.convert_model(tf_quantized_model) fp32_ir_path = ROOT / "mobilenet_v2_fp32.xml" ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False) diff --git a/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py b/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py index cf3bc372887..233ec512727 100644 --- a/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py +++ b/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py @@ -167,8 +167,8 @@ def transform_fn(data_item): ############################################################################### # Benchmark performance, calculate compression rate and validate accuracy -ov_model = ov.convert_model(tf_model, share_weights=False) -ov_quantized_model = ov.convert_model(stripped_model, share_weights=False) +ov_model = ov.convert_model(tf_model) +ov_quantized_model = ov.convert_model(stripped_model) fp32_ir_path = ROOT / "mobilenet_v2_fp32.xml" ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False) diff --git a/nncf/tensorflow/__init__.py b/nncf/tensorflow/__init__.py index 1bd6cef967a..cc71a9d6c91 100644 --- a/nncf/tensorflow/__init__.py +++ b/nncf/tensorflow/__init__.py @@ -44,6 +44,8 @@ ) from nncf.tensorflow.helpers import create_compressed_model as create_compressed_model from nncf.tensorflow.helpers.callback_creation import create_compression_callbacks as create_compression_callbacks +from nncf.tensorflow.helpers.model_creation import get_config +from nncf.tensorflow.helpers.model_creation import load_from_config from nncf.tensorflow.initialization import register_default_init_args as register_default_init_args from nncf.tensorflow.pruning.filter_pruning import algorithm as filter_pruning_algorithm @@ -51,3 +53,4 @@ from nncf.tensorflow.quantization import algorithm as quantization_algorithm from nncf.tensorflow.sparsity.magnitude import algorithm as magnitude_sparsity_algorithm from nncf.tensorflow.sparsity.rb import algorithm as rb_sparsity_algorithm +from nncf.tensorflow.utils.state import ConfigState diff --git a/nncf/tensorflow/helpers/model_creation.py b/nncf/tensorflow/helpers/model_creation.py index 3edbb41880e..b3c506edac4 100644 --- a/nncf/tensorflow/helpers/model_creation.py +++ b/nncf/tensorflow/helpers/model_creation.py @@ -18,19 +18,25 @@ from nncf import NNCFConfig from nncf.api.compression import CompressionAlgorithmController from nncf.common.compression import 
BaseCompressionAlgorithmController as BaseController +from nncf.common.deprecation import warning_deprecated from nncf.common.utils.api_marker import api from nncf.config.extractors import extract_algorithm_names from nncf.config.telemetry_extractors import CompressionStartedFromConfig from nncf.config.utils import is_experimental_quantization from nncf.telemetry import tracked_function from nncf.telemetry.events import NNCF_TF_CATEGORY +from nncf.telemetry.extractors import FunctionCallTelemetryExtractor from nncf.tensorflow.accuracy_aware_training.keras_model_utils import accuracy_aware_fit from nncf.tensorflow.algorithm_selector import NoCompressionAlgorithmBuilder from nncf.tensorflow.algorithm_selector import get_compression_algorithm_builder from nncf.tensorflow.api.composite_compression import TFCompositeCompressionAlgorithmBuilder from nncf.tensorflow.api.compression import TFCompressionAlgorithmBuilder +from nncf.tensorflow.graph.model_transformer import TFModelTransformer +from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout from nncf.tensorflow.graph.utils import is_keras_layer_model from nncf.tensorflow.helpers.utils import get_built_model +from nncf.tensorflow.quantization.algorithm import QuantizationBuilder +from nncf.tensorflow.quantization.algorithm import TFQuantizationSetup def create_compression_algorithm_builder(config: NNCFConfig, should_init: bool) -> TFCompressionAlgorithmBuilder: @@ -80,6 +86,27 @@ def create_compressed_model( :return: A tuple of the compression controller for the requested algorithm(s) and the model object with additional modifications necessary to enable algorithm-specific compression during fine-tuning. """ + + warning_deprecated( + "The 'nncf.tensorflow.create_compressed_model' function is deprecated and will be removed in a " + "future release.\n" + "To perform post training quantization (PTQ) or quantization aware training (QAT)," + " use the nncf.quantize() API:\n" + " - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#post-training-quantization\n" + " - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#training-time-quantization\n" + "Examples:\n" + " - https://github.com/openvinotoolkit/nncf/tree/develop/examples/post_training_quantization/tensorflow\n" + " - https://github.com/openvinotoolkit/nncf/tree/develop/examples/quantization_aware_training/tensorflow" + ) + return create_compressed_model_impl(model, config, compression_state) + + +def create_compressed_model_impl( + model: tf.keras.Model, config: NNCFConfig, compression_state: Optional[Dict[str, Any]] = None +) -> Tuple[CompressionAlgorithmController, tf.keras.Model]: + """ + Implementation of the create_compressed_model() method. + """ if is_experimental_quantization(config): if is_keras_layer_model(model): raise ValueError( @@ -126,3 +153,47 @@ def get_input_signature(config: NNCFConfig): input_signature.append(tf.TensorSpec(shape=shape, dtype=tf.float32)) return input_signature if len(input_signature) > 1 else input_signature[0] + + +@tracked_function( + NNCF_TF_CATEGORY, + [ + FunctionCallTelemetryExtractor("nncf.tensorflow.load_from_config"), + ], +) +def load_from_config(model: tf.keras.Model, config: Dict[str, Any]) -> tf.keras.Model: + """ + Recovers additional modules from given config. + Does not recover additional modules weights as they are located in a corresponded checkpoint file. + + :param model: TensorFlow model. + :parem config: Config. 
+ :return: tf.keras.Model builded from given model with additional layers recovered from given config. + """ + quantizer_setup_state = config["quantization"]["quantizer_setup"] + quantizer_setup = TFQuantizationSetup.from_state(quantizer_setup_state) + + transformation_layout = TFTransformationLayout() + # pylint: disable=protected-access + insertion_commands, _ = QuantizationBuilder.build_insertion_commands_for_quantizer_setup(quantizer_setup) + for command in insertion_commands: + transformation_layout.register(command) + model_transformer = TFModelTransformer(model) + return model_transformer.transform(transformation_layout) + + +@tracked_function( + NNCF_TF_CATEGORY, + [ + FunctionCallTelemetryExtractor("nncf.tensorflow.get_config"), + ], +) +def get_config(model: tf.keras.Model) -> Dict[str, Any]: + """ + Extracts the config from the model. + + :param model: Model. + :return: Config. + """ + config = getattr(model, "_nncf_config") + return config diff --git a/nncf/tensorflow/quantization/algorithm.py b/nncf/tensorflow/quantization/algorithm.py index 7a8a23a0ed7..d2cb5fc450d 100644 --- a/nncf/tensorflow/quantization/algorithm.py +++ b/nncf/tensorflow/quantization/algorithm.py @@ -346,14 +346,17 @@ def _get_half_range( return True return False - def _create_quantizer(self, name: str, qspec: TFQuantizerSpec) -> Quantizer: + @staticmethod + def _create_quantizer(name: str, qspec: TFQuantizerSpec) -> Quantizer: quantizer_cls = NNCF_QUANTIZATION_OPERATIONS.get(qspec.mode) return quantizer_cls(name, qspec) - def _build_insertion_commands_for_quantizer_setup( - self, quantizer_setup: TFQuantizationSetup - ) -> List[TFInsertionCommand]: + @staticmethod + def build_insertion_commands_for_quantizer_setup( + quantizer_setup: TFQuantizationSetup, + ) -> Tuple[List[TFInsertionCommand], List[str]]: insertion_commands = [] + op_names = [] quantization_points = quantizer_setup.get_quantization_points() non_unified_scales_quantization_point_ids = set(range(len(quantization_points))) @@ -365,7 +368,7 @@ def _build_insertion_commands_for_quantizer_setup( quantizer_spec = qp.quantizer_spec op_name = qp.op_name + "/unified_scale_group" quantizer = FakeQuantize(quantizer_spec, name=op_name) - self._op_names.append(quantizer.op_name) + op_names.append(quantizer.op_name) target_points = [] for us_qp_id in unified_scales_group: non_unified_scales_quantization_point_ids.discard(us_qp_id) @@ -387,24 +390,26 @@ def _build_insertion_commands_for_quantizer_setup( quantizer_spec = quantization_point.quantizer_spec target_point = quantization_point.target_point if quantization_point.is_weight_quantization(): - quantizer = self._create_quantizer(op_name, quantizer_spec) - self._op_names.append(op_name) + quantizer = QuantizationBuilder._create_quantizer(op_name, quantizer_spec) + op_names.append(op_name) else: quantizer = FakeQuantize(quantizer_spec, name=op_name) - self._op_names.append(quantizer.op_name) + op_names.append(quantizer.op_name) command = TFInsertionCommand( target_point=target_point, callable_object=quantizer, priority=TransformationPriority.QUANTIZATION_PRIORITY, ) insertion_commands.append(command) - return insertion_commands + return insertion_commands, op_names def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLayout: transformations = TFTransformationLayout() if self._quantizer_setup is None: self._quantizer_setup = self._get_quantizer_setup(model) - insertion_commands = self._build_insertion_commands_for_quantizer_setup(self._quantizer_setup) + insertion_commands, 
self._op_names = QuantizationBuilder.build_insertion_commands_for_quantizer_setup( + self._quantizer_setup + ) for command in insertion_commands: transformations.register(command) return transformations diff --git a/nncf/tensorflow/quantization/quantize_model.py b/nncf/tensorflow/quantization/quantize_model.py index 02da5a28ab4..20d164070bb 100644 --- a/nncf/tensorflow/quantization/quantize_model.py +++ b/nncf/tensorflow/quantization/quantize_model.py @@ -28,7 +28,7 @@ from nncf.quantization.advanced_parameters import apply_advanced_parameters_to_config from nncf.scopes import IgnoredScope from nncf.scopes import convert_ignored_scope_to_list -from nncf.tensorflow.helpers.model_creation import create_compressed_model +from nncf.tensorflow.helpers.model_creation import create_compressed_model_impl DEFAULT_RANGE_TYPE = "mean_min_max" @@ -176,6 +176,12 @@ def quantize_impl( ] ) - _, compressed_model = create_compressed_model(model=model, config=nncf_config) + compression_ctrl, compressed_model = create_compressed_model_impl(model=model, config=nncf_config) + + # NOTE: We set the config here to properly save/load the quantized model during training into tf.train.Checkpoint. + # You can obtain that config via the nncf.tensorflow.get_config() method and save/load it to/from + # tf.train.Checkpoint using the nncf.tensorflow.ConfigState class. + config = compression_ctrl.get_compression_state()["builder_state"] + setattr(compressed_model, "_nncf_config", config) return compressed_model diff --git a/nncf/tensorflow/utils/state.py b/nncf/tensorflow/utils/state.py index df223c62623..1cac500579a 100644 --- a/nncf/tensorflow/utils/state.py +++ b/nncf/tensorflow/utils/state.py @@ -10,11 +10,12 @@ # limitations under the License. import json -from typing import Any, Dict +from typing import Any, Dict, Optional import tensorflow as tf from nncf.common.compression import BaseCompressionAlgorithmController +from nncf.tensorflow.quantization.algorithm import TFQuantizationSetup # TODO(achurkin): remove pylint ignore after 120296 ticked is fixed @@ -86,3 +87,37 @@ def deserialize(self, string_value: str) -> None: :param string_value: A serialized compression state. """ self._state = json.loads(string_value) + + +class ConfigState(tf.train.experimental.PythonState): + """ + Used to save/load a config into the tf.train.Checkpoint. + """ + + def __init__(self, config: Optional[Dict[str, Any]] = None): + """ + :param config: Config. + """ + self.config = config + + def serialize(self) -> str: + """ + Callback to serialize the config. + + :return: A serialized config. + """ + quantizer_setup_state = self.config["quantization"]["quantizer_setup"] + data = { + "quantization": { + "quantizer_setup": TFQuantizationSetup.from_state(quantizer_setup_state).get_state(), + } + } + return json.dumps(data) + + def deserialize(self, string_value: str) -> None: + """ + Callback to deserialize the model config. + + :param string_value: A serialized model config. 
+ """ + self.config = json.loads(string_value) diff --git a/nncf/torch/model_creation.py b/nncf/torch/model_creation.py index 8674cb3d0ff..4d55f134fb7 100644 --- a/nncf/torch/model_creation.py +++ b/nncf/torch/model_creation.py @@ -106,7 +106,7 @@ def create_compressed_model( warning_deprecated( "The 'nncf.torch.create_compressed_model' function is deprecated and will be removed in a future release.\n" "To perform post training quantization (PTQ) or quantization aware training (QAT)," - " use the new nncf.quantize() API:\n" + " use the nncf.quantize() API:\n" " - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#post-training-quantization\n" " - https://github.com/openvinotoolkit/nncf?tab=readme-ov-file#training-time-quantization\n" "Examples:\n"
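To see how the new TensorFlow helpers introduced in this diff fit together outside of a `tf.train.Checkpoint`, here is a hedged standalone round-trip of `get_config` and `ConfigState`; `quantized_model` is assumed to come from `nncf.quantize()` applied to a `tf.keras.Model`:

```python
from nncf.tensorflow import ConfigState
from nncf.tensorflow import get_config

# quantize_impl() attaches the builder config to the model, so get_config() can read it back.
nncf_config = get_config(quantized_model)

state = ConfigState(nncf_config)
serialized = state.serialize()        # JSON string holding the quantizer setup state

restored_state = ConfigState()
restored_state.deserialize(serialized)
assert "quantization" in restored_state.config
```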