12
12
from dataclasses import dataclass
13
13
from typing import Any , Callable , Iterable , List , Optional , Tuple , TypeVar , Union
14
14
15
- from nncf .common .factory import EngineFactory
16
15
from nncf .common .logging import nncf_logger
17
16
from nncf .common .utils .backend import BackendType
18
17
from nncf .common .utils .backend import get_backend
19
18
from nncf .common .utils .timer import timer
20
19
from nncf .data .dataset import Dataset
20
+ from nncf .quantization .algorithms .accuracy_control .backend import PreparedModel
21
21
22
22
TModel = TypeVar ("TModel" )
23
- TPModel = TypeVar ("TPModel" )
24
23
TTensor = TypeVar ("TTensor" )
25
24
26
25
@@ -111,7 +110,7 @@ def is_metric_mode(self) -> bool:
111
110
"""
112
111
return self ._metric_mode
113
112
114
- def prepare_model_for_inference (self , model : TModel ) -> TPModel :
113
+ def prepare_model (self , model : TModel ) -> PreparedModel :
115
114
"""
116
115
Prepares model for inference.
117
116
@@ -121,21 +120,19 @@ def prepare_model_for_inference(self, model: TModel) -> TPModel:
121
120
backend = get_backend (model )
122
121
123
122
if backend == BackendType .OPENVINO :
124
- import openvino . runtime as ov
123
+ from nncf . quantization . algorithms . accuracy_control . openvino_backend import OVPreparedModel
125
124
126
- return ov . compile_model (model )
125
+ return OVPreparedModel (model )
127
126
128
- raise NotImplementedError (
129
- f"The `prepare_model_for_inference()` method is not implemented for the { backend } backend."
130
- )
127
+ raise NotImplementedError (f"The `prepare_model()` method is not implemented for the { backend } backend." )
131
128
132
- def validate_model_for_inference (
133
- self , model_for_inference : TPModel , dataset : Dataset , indices : Optional [List [int ]] = None
129
+ def validate_prepared_model (
130
+ self , prepared_model : PreparedModel , dataset : Dataset , indices : Optional [List [int ]] = None
134
131
):
135
132
"""
136
133
Validates prepared model for inference.
137
134
138
- :param model : Prepared model to validate.
135
+ :param prepared_model : Prepared model to validate.
139
136
:param dataset: Dataset to validate the model.
140
137
:param indices: Zero-based indices of data items that should be selected from
141
138
the dataset.
@@ -147,7 +144,7 @@ def validate_model_for_inference(
147
144
item.
148
145
"""
149
146
if self ._metric_mode is None :
150
- self ._metric_mode = Evaluator .determine_mode (model_for_inference , dataset , self ._validation_fn )
147
+ self ._metric_mode = Evaluator .determine_mode (prepared_model , dataset , self ._validation_fn )
151
148
152
149
if not self .is_metric_mode () and indices is not None :
153
150
raise ValueError ("The `indices` parameter can be used only if Evaluator.is_metric_mode() = True" )
@@ -156,7 +153,7 @@ def validate_model_for_inference(
156
153
if self ._enable_iteration_count :
157
154
validation_dataset = IterationCounter (validation_dataset )
158
155
159
- metric , values_for_each_item = self ._validation_fn (model_for_inference , validation_dataset )
156
+ metric , values_for_each_item = self ._validation_fn (prepared_model . model_for_inference , validation_dataset )
160
157
161
158
self ._num_passed_iterations = validation_dataset .num_iterations if self ._enable_iteration_count else 0
162
159
@@ -189,20 +186,20 @@ def validate(
189
186
Otherwise, if the condition is false, it represents list of logits for each
190
187
item.
191
188
"""
192
- model_for_inference = self .prepare_model_for_inference (model )
193
- return self .validate_model_for_inference ( model_for_inference , dataset , indices )
189
+ prepared_model = self .prepare_model (model )
190
+ return self .validate_prepared_model ( prepared_model , dataset , indices )
194
191
195
192
@staticmethod
196
193
def determine_mode (
197
- model_for_inference : TPModel ,
194
+ prepared_model : PreparedModel ,
198
195
dataset : Dataset ,
199
196
validation_fn : Callable [[Any , Iterable [Any ]], Tuple [float , Union [None , List [float ], List [List [TTensor ]]]]],
200
197
) -> bool :
201
198
"""
202
199
Determines mode based on the type of returned value from the
203
200
validation function.
204
201
205
- :param model_for_inference : Model to validate.
202
+ :param prepared_model : Model to validate.
206
203
:param dataset: Dataset to validate the model.
207
204
:param validation_fn: Validation function to validate model.
208
205
:return: A boolean indicator where `True` means that the `Evaluator` collects
@@ -214,7 +211,7 @@ def determine_mode(
214
211
data_item = dataset .get_data ([0 ])
215
212
216
213
try :
217
- metric_value , values_for_each_item = validation_fn (model_for_inference , data_item )
214
+ metric_value , values_for_each_item = validation_fn (prepared_model . model_for_inference , data_item )
218
215
except Exception :
219
216
metric_mode = False
220
217
@@ -261,15 +258,15 @@ def determine_mode(
261
258
262
259
return metric_mode
263
260
264
- def collect_values_for_each_item_using_model_for_inference (
265
- self , model_for_inference : TPModel , dataset : Dataset , indices : Optional [List [int ]] = None
261
+ def collect_values_for_each_item_using_prepared_model (
262
+ self , prepared_model : PreparedModel , dataset : Dataset , indices : Optional [List [int ]] = None
266
263
) -> Union [List [float ], List [List [TTensor ]]]:
267
264
"""
268
265
Collects value for each item from the dataset using prepared model for inference.
269
266
If `is_metric_mode()` returns `True` then i-th value is a metric for i-th data item.
270
267
It is an output of the model for i-th data item otherwise.
271
268
272
- :param model : Model to infer.
269
+ :param prepared_model : Model to infer.
273
270
:param dataset: Dataset to collect values.
274
271
:param indices: The zero-based indices of data items that should be selected from
275
272
the dataset.
@@ -278,15 +275,14 @@ def collect_values_for_each_item_using_model_for_inference(
278
275
if self ._metric_mode :
279
276
# Collect metrics for each item
280
277
values_for_each_item = [
281
- self ._validation_fn (model_for_inference , [data_item ])[0 ] for data_item in dataset .get_data (indices )
278
+ self ._validation_fn (prepared_model .model_for_inference , [data_item ])[0 ]
279
+ for data_item in dataset .get_data (indices )
282
280
]
283
281
else :
284
282
# Collect outputs for each item
285
- engine = EngineFactory .create (model_for_inference )
286
-
287
283
values_for_each_item = []
288
284
for data_item in dataset .get_inference_data (indices ):
289
- logits = engine . infer (data_item )
285
+ logits = prepared_model (data_item )
290
286
values_for_each_item .append (list (logits .values ()))
291
287
292
288
self ._num_passed_iterations = len (values_for_each_item ) if self ._enable_iteration_count else 0
@@ -307,8 +303,8 @@ def collect_values_for_each_item(
307
303
the dataset.
308
304
:return: Collected values.
309
305
"""
310
- model_for_inference = self .prepare_model_for_inference (model )
311
- return self .collect_values_for_each_item_using_model_for_inference ( model_for_inference , dataset , indices )
306
+ prepared_model = self .prepare_model (model )
307
+ return self .collect_values_for_each_item_using_prepared_model ( prepared_model , dataset , indices )
312
308
313
309
def collect_metric_results (self , model : TModel , dataset : Dataset , model_name : str = "" ) -> MetricResults :
314
310
"""
@@ -322,18 +318,16 @@ def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: st
322
318
nncf_logger .info (f"Validation of { model_name } model was started" )
323
319
324
320
with timer () as preparation_time :
325
- model_for_inference = self .prepare_model_for_inference (model )
321
+ prepared_model = self .prepare_model (model )
326
322
327
323
with timer () as validation_time :
328
- metric , values_for_each_item = self .validate_model_for_inference ( model_for_inference , dataset )
324
+ metric , values_for_each_item = self .validate_prepared_model ( prepared_model , dataset )
329
325
330
326
nncf_logger .info (f"Metric of { model_name } model: { metric } " )
331
327
332
328
if values_for_each_item is None :
333
329
nncf_logger .info (f"Collecting values for each data item using the { model_name } model" )
334
330
with timer ():
335
- values_for_each_item = self .collect_values_for_each_item_using_model_for_inference (
336
- model_for_inference , dataset
337
- )
331
+ values_for_each_item = self .collect_values_for_each_item_using_prepared_model (prepared_model , dataset )
338
332
339
333
return MetricResults (metric , values_for_each_item , preparation_time (), validation_time ())
0 commit comments