Skip to content

Commit

Permalink
add tests (#3208)
Browse files Browse the repository at this point in the history
### Changes

- Add more tests to be sure `nncf.strip()` doesn't change model without
compression.
  • Loading branch information
andrey-churkin authored Jan 29, 2025
1 parent e8ff50c commit 0333814
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
3 changes: 3 additions & 0 deletions nncf/tensorflow/strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def strip(model: tf.keras.Model, do_copy: bool = True) -> tf.keras.Model:
will return the currently associated model object "stripped" in-place.
:return: The stripped model.
"""
if not isinstance(model, tf.keras.Model):
return model

# Check to understand if the model is after NNCF or not.
wrapped_layers = collect_wrapped_layers(model)
if not wrapped_layers:
Expand Down
47 changes: 47 additions & 0 deletions tests/tensorflow/quantization/test_strip.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test
from tests.tensorflow.helpers import get_basic_two_conv_test_model
from tests.tensorflow.quantization.utils import get_basic_quantization_config
from tests.tensorflow.test_models.mobilenet_v2 import MobileNetV2
from tests.tensorflow.test_models.retinanet import RetinaNet
from tests.tensorflow.test_models.yolo_v4 import YOLOv4


def test_strip():
Expand Down Expand Up @@ -89,3 +92,47 @@ def test_strip_api_do_copy(do_copy):
assert id(stripped_model) != id(compressed_model)
else:
assert id(stripped_model) == id(compressed_model)


class SimpleModel(tf.keras.Model):
def __init__(self):
super().__init__()
self._conv = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu")

self._bn_0 = tf.keras.layers.BatchNormalization()
self._bn_1 = tf.keras.layers.BatchNormalization()
self._add = tf.keras.layers.Add()
self._flatten = tf.keras.layers.Flatten()

def call(self, inputs, training=None, mask=None):
input_0, input_1 = inputs

x_0 = self._conv(input_0)
x_0 = self._bn_0(x_0, training=training)

x_1 = self._conv(input_1)
x_1 = self._bn_1(x_1, training=training)

x_0 = self._flatten(x_0)
x_1 = self._flatten(x_1)
outputs = self._add([x_0, x_1])
return outputs

def get_config(self):
raise NotImplementedError


def create_sequential_model():
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(None, None, 3)))
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(1, 1), strides=2, activation="relu"))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(2, activation="relu"))
return model


@pytest.mark.parametrize("model_fn", (MobileNetV2, RetinaNet, SimpleModel, YOLOv4, create_sequential_model))
def test_strip_api_no_compression(model_fn):
model = model_fn()
stripped_model = nncf.strip(model)
assert id(stripped_model) == id(model)

0 comments on commit 0333814

Please sign in to comment.