From 0333814fc32092c28799dc2fbe468c11f32f598c Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Wed, 29 Jan 2025 09:51:16 +0000 Subject: [PATCH] add tests (#3208) ### Changes - Add more tests to be sure `nncf.strip()` doesn't change model without compression. --- nncf/tensorflow/strip.py | 3 ++ tests/tensorflow/quantization/test_strip.py | 47 +++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/nncf/tensorflow/strip.py b/nncf/tensorflow/strip.py index 72a1e1123ca..2159ba04cc2 100644 --- a/nncf/tensorflow/strip.py +++ b/nncf/tensorflow/strip.py @@ -37,6 +37,9 @@ def strip(model: tf.keras.Model, do_copy: bool = True) -> tf.keras.Model: will return the currently associated model object "stripped" in-place. :return: The stripped model. """ + if not isinstance(model, tf.keras.Model): + return model + # Check to understand if the model is after NNCF or not. wrapped_layers = collect_wrapped_layers(model) if not wrapped_layers: diff --git a/tests/tensorflow/quantization/test_strip.py b/tests/tensorflow/quantization/test_strip.py index 89d538882f4..541c01af846 100644 --- a/tests/tensorflow/quantization/test_strip.py +++ b/tests/tensorflow/quantization/test_strip.py @@ -17,6 +17,9 @@ from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test from tests.tensorflow.helpers import get_basic_two_conv_test_model from tests.tensorflow.quantization.utils import get_basic_quantization_config +from tests.tensorflow.test_models.mobilenet_v2 import MobileNetV2 +from tests.tensorflow.test_models.retinanet import RetinaNet +from tests.tensorflow.test_models.yolo_v4 import YOLOv4 def test_strip(): @@ -89,3 +92,47 @@ def test_strip_api_do_copy(do_copy): assert id(stripped_model) != id(compressed_model) else: assert id(stripped_model) == id(compressed_model) + + +class SimpleModel(tf.keras.Model): + def __init__(self): + super().__init__() + self._conv = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu") + + self._bn_0 = tf.keras.layers.BatchNormalization() + self._bn_1 = tf.keras.layers.BatchNormalization() + self._add = tf.keras.layers.Add() + self._flatten = tf.keras.layers.Flatten() + + def call(self, inputs, training=None, mask=None): + input_0, input_1 = inputs + + x_0 = self._conv(input_0) + x_0 = self._bn_0(x_0, training=training) + + x_1 = self._conv(input_1) + x_1 = self._bn_1(x_1, training=training) + + x_0 = self._flatten(x_0) + x_1 = self._flatten(x_1) + outputs = self._add([x_0, x_1]) + return outputs + + def get_config(self): + raise NotImplementedError + + +def create_sequential_model(): + model = tf.keras.Sequential() + model.add(tf.keras.layers.Input(shape=(None, None, 3))) + model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(1, 1), strides=2, activation="relu")) + model.add(tf.keras.layers.BatchNormalization()) + model.add(tf.keras.layers.Dense(2, activation="relu")) + return model + + +@pytest.mark.parametrize("model_fn", (MobileNetV2, RetinaNet, SimpleModel, YOLOv4, create_sequential_model)) +def test_strip_api_no_compression(model_fn): + model = model_fn() + stripped_model = nncf.strip(model) + assert id(stripped_model) == id(model)