|
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 | import os
|
| 7 | +import numpy as np |
| 8 | +import openvino as ov |
7 | 9 | from pathlib import Path
|
8 |
| -from openvino.utils import deprecated, get_cmake_path |
9 |
| -from tests.utils.helpers import compare_models, get_relu_model |
| 10 | +from openvino.utils import deprecated, get_cmake_path, make_postponed_constant |
| 11 | +from tests.utils.helpers import compare_models, get_relu_model, create_filenames_for_ir |
10 | 12 |
|
11 | 13 |
|
12 | 14 | def test_compare_functions():
|
@@ -94,3 +96,95 @@ def mock_walk(path):
|
94 | 96 | result = get_cmake_path()
|
95 | 97 |
|
96 | 98 | assert result == ""
|
| 99 | + |
| 100 | + |
| 101 | +class Maker: |
| 102 | + def __init__(self): |
| 103 | + self.calls_count = 0 |
| 104 | + |
| 105 | + def __call__(self, tensor: ov.Tensor) -> None: |
| 106 | + self.calls_count += 1 |
| 107 | + tensor_data = np.array([2, 2, 2, 2], dtype=np.float32).reshape(1, 1, 2, 2) |
| 108 | + ov.Tensor(tensor_data).copy_to(tensor) |
| 109 | + |
| 110 | + def called_times(self): |
| 111 | + return self.calls_count |
| 112 | + |
| 113 | + |
| 114 | +def create_model(maker): |
| 115 | + input_shape = ov.Shape([1, 2, 1, 2]) |
| 116 | + param_node = ov.opset13.parameter(input_shape, ov.Type.f32, name="data") |
| 117 | + |
| 118 | + postponned_constant = make_postponed_constant(ov.Type.f32, input_shape, maker) |
| 119 | + |
| 120 | + add_1 = ov.opset13.add(param_node, postponned_constant) |
| 121 | + |
| 122 | + const_2 = ov.op.Constant(ov.Type.f32, input_shape, [1, 2, 3, 4]) |
| 123 | + add_2 = ov.opset13.add(add_1, const_2) |
| 124 | + |
| 125 | + return ov.Model(add_2, [param_node], "test_model") |
| 126 | + |
| 127 | + |
| 128 | +@pytest.fixture |
| 129 | +def prepare_ir_paths(request, tmp_path): |
| 130 | + xml_path, bin_path = create_filenames_for_ir(request.node.name, tmp_path) |
| 131 | + |
| 132 | + yield xml_path, bin_path |
| 133 | + |
| 134 | + # IR Files deletion should be done after `Model` is destructed. |
| 135 | + # It may be achieved by splitting scopes (`Model` will be destructed |
| 136 | + # just after test scope finished), or by calling `del Model` |
| 137 | + os.remove(xml_path) |
| 138 | + os.remove(bin_path) |
| 139 | + |
| 140 | + |
| 141 | +def test_save_postponned_constant(prepare_ir_paths): |
| 142 | + maker = Maker() |
| 143 | + model = create_model(maker) |
| 144 | + assert maker.called_times() == 0 |
| 145 | + |
| 146 | + model_export_file_name, weights_export_file_name = prepare_ir_paths |
| 147 | + ov.save_model(model, model_export_file_name, compress_to_fp16=False) |
| 148 | + |
| 149 | + assert maker.called_times() == 1 |
| 150 | + |
| 151 | + |
| 152 | +def test_save_postponned_constant_twice(prepare_ir_paths): |
| 153 | + maker = Maker() |
| 154 | + model = create_model(maker) |
| 155 | + assert maker.called_times() == 0 |
| 156 | + |
| 157 | + model_export_file_name, weights_export_file_name = prepare_ir_paths |
| 158 | + ov.save_model(model, model_export_file_name, compress_to_fp16=False) |
| 159 | + assert maker.called_times() == 1 |
| 160 | + ov.save_model(model, model_export_file_name, compress_to_fp16=False) |
| 161 | + assert maker.called_times() == 2 |
| 162 | + |
| 163 | + |
| 164 | +def test_serialize_postponned_constant(prepare_ir_paths): |
| 165 | + maker = Maker() |
| 166 | + model = create_model(maker) |
| 167 | + assert maker.called_times() == 0 |
| 168 | + |
| 169 | + model_export_file_name, weights_export_file_name = prepare_ir_paths |
| 170 | + ov.serialize(model, model_export_file_name, weights_export_file_name) |
| 171 | + assert maker.called_times() == 1 |
| 172 | + |
| 173 | + |
| 174 | +def test_infer_postponned_constant(): |
| 175 | + maker = Maker() |
| 176 | + model = create_model(maker) |
| 177 | + assert maker.called_times() == 0 |
| 178 | + |
| 179 | + compiled_model = ov.compile_model(model, "CPU") |
| 180 | + assert isinstance(compiled_model, ov.CompiledModel) |
| 181 | + |
| 182 | + request = compiled_model.create_infer_request() |
| 183 | + input_data = np.ones([1, 2, 1, 2], dtype=np.float32) |
| 184 | + input_tensor = ov.Tensor(input_data) |
| 185 | + |
| 186 | + results = request.infer({"data": input_tensor}) |
| 187 | + assert maker.called_times() == 1 |
| 188 | + |
| 189 | + expected_output = np.array([4, 5, 6, 7], dtype=np.float32).reshape(1, 2, 1, 2) |
| 190 | + assert np.allclose(results[list(results)[0]], expected_output, 1e-4, 1e-4) |
0 commit comments