1212from dataclasses import dataclass
1313from typing import Any , Callable , Iterable , List , Optional , Tuple , TypeVar , Union
1414
15- from nncf .common .factory import EngineFactory
1615from nncf .common .logging import nncf_logger
1716from nncf .common .utils .backend import BackendType
1817from nncf .common .utils .backend import get_backend
1918from nncf .common .utils .timer import timer
2019from nncf .data .dataset import Dataset
20+ from nncf .quantization .algorithms .accuracy_control .backend import PreparedModel
2121
2222TModel = TypeVar ("TModel" )
23- TPModel = TypeVar ("TPModel" )
2423TTensor = TypeVar ("TTensor" )
2524
2625
@@ -111,7 +110,7 @@ def is_metric_mode(self) -> bool:
111110 """
112111 return self ._metric_mode
113112
114- def prepare_model_for_inference (self , model : TModel ) -> TPModel :
113+ def prepare_model (self , model : TModel ) -> PreparedModel :
115114 """
116115 Prepares model for inference.
117116
@@ -121,21 +120,19 @@ def prepare_model_for_inference(self, model: TModel) -> TPModel:
121120 backend = get_backend (model )
122121
123122 if backend == BackendType .OPENVINO :
124- import openvino . runtime as ov
123+ from nncf . quantization . algorithms . accuracy_control . openvino_backend import OVPreparedModel
125124
126- return ov . compile_model (model )
125+ return OVPreparedModel (model )
127126
128- raise NotImplementedError (
129- f"The `prepare_model_for_inference()` method is not implemented for the { backend } backend."
130- )
127+ raise NotImplementedError (f"The `prepare_model()` method is not implemented for the { backend } backend." )
131128
132- def validate_model_for_inference (
133- self , model_for_inference : TPModel , dataset : Dataset , indices : Optional [List [int ]] = None
129+ def validate_prepared_model (
130+ self , prepared_model : PreparedModel , dataset : Dataset , indices : Optional [List [int ]] = None
134131 ):
135132 """
136133 Validates prepared model for inference.
137134
138- :param model : Prepared model to validate.
135+ :param prepared_model : Prepared model to validate.
139136 :param dataset: Dataset to validate the model.
140137 :param indices: Zero-based indices of data items that should be selected from
141138 the dataset.
@@ -147,7 +144,7 @@ def validate_model_for_inference(
147144 item.
148145 """
149146 if self ._metric_mode is None :
150- self ._metric_mode = Evaluator .determine_mode (model_for_inference , dataset , self ._validation_fn )
147+ self ._metric_mode = Evaluator .determine_mode (prepared_model , dataset , self ._validation_fn )
151148
152149 if not self .is_metric_mode () and indices is not None :
153150 raise ValueError ("The `indices` parameter can be used only if Evaluator.is_metric_mode() = True" )
@@ -156,7 +153,7 @@ def validate_model_for_inference(
156153 if self ._enable_iteration_count :
157154 validation_dataset = IterationCounter (validation_dataset )
158155
159- metric , values_for_each_item = self ._validation_fn (model_for_inference , validation_dataset )
156+ metric , values_for_each_item = self ._validation_fn (prepared_model . model_for_inference , validation_dataset )
160157
161158 self ._num_passed_iterations = validation_dataset .num_iterations if self ._enable_iteration_count else 0
162159
@@ -189,20 +186,20 @@ def validate(
189186 Otherwise, if the condition is false, it represents list of logits for each
190187 item.
191188 """
192- model_for_inference = self .prepare_model_for_inference (model )
193- return self .validate_model_for_inference ( model_for_inference , dataset , indices )
189+ prepared_model = self .prepare_model (model )
190+ return self .validate_prepared_model ( prepared_model , dataset , indices )
194191
195192 @staticmethod
196193 def determine_mode (
197- model_for_inference : TPModel ,
194+ prepared_model : PreparedModel ,
198195 dataset : Dataset ,
199196 validation_fn : Callable [[Any , Iterable [Any ]], Tuple [float , Union [None , List [float ], List [List [TTensor ]]]]],
200197 ) -> bool :
201198 """
202199 Determines mode based on the type of returned value from the
203200 validation function.
204201
205- :param model_for_inference : Model to validate.
202+ :param prepared_model : Model to validate.
206203 :param dataset: Dataset to validate the model.
207204 :param validation_fn: Validation function to validate model.
208205 :return: A boolean indicator where `True` means that the `Evaluator` collects
@@ -214,7 +211,7 @@ def determine_mode(
214211 data_item = dataset .get_data ([0 ])
215212
216213 try :
217- metric_value , values_for_each_item = validation_fn (model_for_inference , data_item )
214+ metric_value , values_for_each_item = validation_fn (prepared_model . model_for_inference , data_item )
218215 except Exception :
219216 metric_mode = False
220217
@@ -261,15 +258,15 @@ def determine_mode(
261258
262259 return metric_mode
263260
264- def collect_values_for_each_item_using_model_for_inference (
265- self , model_for_inference : TPModel , dataset : Dataset , indices : Optional [List [int ]] = None
261+ def collect_values_for_each_item_using_prepared_model (
262+ self , prepared_model : PreparedModel , dataset : Dataset , indices : Optional [List [int ]] = None
266263 ) -> Union [List [float ], List [List [TTensor ]]]:
267264 """
268265 Collects value for each item from the dataset using prepared model for inference.
269266 If `is_metric_mode()` returns `True` then i-th value is a metric for i-th data item.
270267 It is an output of the model for i-th data item otherwise.
271268
272- :param model : Model to infer.
269+ :param prepared_model : Model to infer.
273270 :param dataset: Dataset to collect values.
274271 :param indices: The zero-based indices of data items that should be selected from
275272 the dataset.
@@ -278,15 +275,14 @@ def collect_values_for_each_item_using_model_for_inference(
278275 if self ._metric_mode :
279276 # Collect metrics for each item
280277 values_for_each_item = [
281- self ._validation_fn (model_for_inference , [data_item ])[0 ] for data_item in dataset .get_data (indices )
278+ self ._validation_fn (prepared_model .model_for_inference , [data_item ])[0 ]
279+ for data_item in dataset .get_data (indices )
282280 ]
283281 else :
284282 # Collect outputs for each item
285- engine = EngineFactory .create (model_for_inference )
286-
287283 values_for_each_item = []
288284 for data_item in dataset .get_inference_data (indices ):
289- logits = engine . infer (data_item )
285+ logits = prepared_model (data_item )
290286 values_for_each_item .append (list (logits .values ()))
291287
292288 self ._num_passed_iterations = len (values_for_each_item ) if self ._enable_iteration_count else 0
@@ -307,8 +303,8 @@ def collect_values_for_each_item(
307303 the dataset.
308304 :return: Collected values.
309305 """
310- model_for_inference = self .prepare_model_for_inference (model )
311- return self .collect_values_for_each_item_using_model_for_inference ( model_for_inference , dataset , indices )
306+ prepared_model = self .prepare_model (model )
307+ return self .collect_values_for_each_item_using_prepared_model ( prepared_model , dataset , indices )
312308
313309 def collect_metric_results (self , model : TModel , dataset : Dataset , model_name : str = "" ) -> MetricResults :
314310 """
@@ -322,18 +318,16 @@ def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: st
322318 nncf_logger .info (f"Validation of { model_name } model was started" )
323319
324320 with timer () as preparation_time :
325- model_for_inference = self .prepare_model_for_inference (model )
321+ prepared_model = self .prepare_model (model )
326322
327323 with timer () as validation_time :
328- metric , values_for_each_item = self .validate_model_for_inference ( model_for_inference , dataset )
324+ metric , values_for_each_item = self .validate_prepared_model ( prepared_model , dataset )
329325
330326 nncf_logger .info (f"Metric of { model_name } model: { metric } " )
331327
332328 if values_for_each_item is None :
333329 nncf_logger .info (f"Collecting values for each data item using the { model_name } model" )
334330 with timer ():
335- values_for_each_item = self .collect_values_for_each_item_using_model_for_inference (
336- model_for_inference , dataset
337- )
331+ values_for_each_item = self .collect_values_for_each_item_using_prepared_model (prepared_model , dataset )
338332
339333 return MetricResults (metric , values_for_each_item , preparation_time (), validation_time ())
0 commit comments