diff --git a/notebooks/214-grammar-correction/utils.py b/notebooks/214-grammar-correction/utils.py index a48d3fe3967..e3989d120c5 100644 --- a/notebooks/214-grammar-correction/utils.py +++ b/notebooks/214-grammar-correction/utils.py @@ -7,37 +7,13 @@ from jiwer import wer, wer_standardize from nncf.quantization.range_estimator import RangeEstimatorParameters, StatisticsCollectorParameters, StatisticsType from optimum.intel import OVModelForSeq2SeqLM +from optimum.intel.openvino.quantization import InferRequestWrapper from pathlib import Path from tqdm.auto import tqdm from typing import List, Dict from transformers import Pipeline, pipeline, PreTrainedTokenizer CALIBRATION_DATASET_SIZE = 10 -COLLECT_CALIBRATION_DATA = False - - -@contextmanager -def calibration_data_collection(): - global COLLECT_CALIBRATION_DATA - try: - COLLECT_CALIBRATION_DATA = True - yield - finally: - COLLECT_CALIBRATION_DATA = False - - -def wrap_for_data_collection(ov_decoder, calibration_data): - original_fn = ov_decoder.request.start_async - if original_fn.__name__ == "wrapper": - # Already wrapped - return - - def wrapper(*args, **kwargs): - inputs = kwargs.get("inputs", args[0]) - if COLLECT_CALIBRATION_DATA: - calibration_data.append(inputs) - return original_fn(*args, **kwargs) - ov_decoder.request.start_async = wrapper def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_dataset_size: int) -> List[Dict]: @@ -45,15 +21,17 @@ def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_ ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past # Wrap decoder inference for data collection - wrap_for_data_collection(ov_decoder, calibration_data) + original_infer_request = ov_decoder.request + ov_decoder.request = InferRequestWrapper(original_infer_request, calibration_data) # Run inference for data collection calibration_dataset = datasets.load_dataset("jfleg", split="validation").shuffle(seed=42)[:calibration_dataset_size] - with calibration_data_collection(): - for data_item in tqdm(calibration_dataset["sentence"], total=calibration_dataset_size, - desc="Collecting calibration data"): - grammar_corrector_pipe_fp32(data_item) - assert isinstance(calibration_data[0], dict) + for data_item in tqdm(calibration_dataset["sentence"], total=calibration_dataset_size, + desc="Collecting calibration data"): + grammar_corrector_pipe_fp32(data_item) + + ov_decoder.request = original_infer_request + return calibration_data