Skip to content

Commit

Permalink
[214-grammar-correction] Switch to InferRequestWrapper from optimum-intel for calibration data collection (#1777)
Browse files

Switch to InferRequestWrapper from optimum-intel for calibration data
collection. The previous approach has an issue similar to the one described in
huggingface/optimum-intel#577.
  • Loading branch information
nikita-savelyevv authored Mar 3, 2024
1 parent 7697a7d commit 19f10e5
Showing 1 changed file with 9 additions and 31 deletions.
40 changes: 9 additions & 31 deletions notebooks/214-grammar-correction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,31 @@
from jiwer import wer, wer_standardize
from nncf.quantization.range_estimator import RangeEstimatorParameters, StatisticsCollectorParameters, StatisticsType
from optimum.intel import OVModelForSeq2SeqLM
from optimum.intel.openvino.quantization import InferRequestWrapper
from pathlib import Path
from tqdm.auto import tqdm
from typing import List, Dict
from transformers import Pipeline, pipeline, PreTrainedTokenizer

# Number of calibration samples. Not referenced by the code visible here:
# collect_calibration_data() receives the size as an explicit argument, so this
# is presumably the value callers pass in -- TODO confirm against the notebook.
CALIBRATION_DATASET_SIZE = 10
# Module-level switch for input recording. calibration_data_collection() sets it
# to True for the duration of its `with` block and back to False on exit; the
# wrapper installed by wrap_for_data_collection() reads it on every call.
COLLECT_CALIBRATION_DATA = False


@contextmanager
def calibration_data_collection():
    """Turn calibration-data recording on for the duration of a ``with`` block.

    The module-level ``COLLECT_CALIBRATION_DATA`` flag is set to ``True`` on
    entry and is always set back to ``False`` on exit, including when the body
    of the ``with`` block raises.
    """
    global COLLECT_CALIBRATION_DATA
    # A plain global assignment cannot fail, so it can sit ahead of the guarded
    # region; only the suspended body needs the cleanup guarantee.
    COLLECT_CALIBRATION_DATA = True
    try:
        yield
    finally:
        COLLECT_CALIBRATION_DATA = False


def wrap_for_data_collection(ov_decoder, calibration_data):
    """Patch ``ov_decoder.request.start_async`` so that its inputs can be recorded.

    The replacement forwards every call unchanged to the original
    ``start_async`` and returns its result. While the module-level
    ``COLLECT_CALIBRATION_DATA`` flag is true, the call's inputs are also
    appended to ``calibration_data``.

    :param ov_decoder: object exposing ``request.start_async``; the attribute is
        replaced in place.
    :param calibration_data: list that receives the recorded inputs. Inputs are
        appended as passed (no copy is taken).
    :return: ``None``. Calling this again on an already patched decoder is a no-op.
    """
    original_fn = ov_decoder.request.start_async
    if getattr(original_fn, "__name__", None) == "wrapper":
        # Already wrapped: a second layer would record every input twice.
        return

    def wrapper(*args, **kwargs):
        if COLLECT_CALIBRATION_DATA:
            # Check the keyword before the positional slot. The former
            # `kwargs.get("inputs", args[0])` evaluated `args[0]` eagerly and
            # raised IndexError for keyword-only calls such as
            # `start_async(inputs=...)`.
            if "inputs" in kwargs:
                calibration_data.append(kwargs["inputs"])
            elif args:
                calibration_data.append(args[0])
            # A call without any inputs has nothing to record; it is still
            # forwarded below.
        return original_fn(*args, **kwargs)

    ov_decoder.request.start_async = wrapper


def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_dataset_size: int) -> List[Dict]:
    """Run the FP32 pipeline on JFLEG sentences and gather decoder inputs for calibration.

    The ``decoder_with_past`` infer request of the pipeline's model is
    temporarily replaced by an ``InferRequestWrapper`` that is handed the
    result list. The original request is put back afterwards, even if
    inference fails part-way through.

    :param grammar_corrector_pipe_fp32: pipeline whose ``model.decoder_with_past``
        is instrumented; it is called once per sentence.
    :param calibration_dataset_size: number of sentences taken from the shuffled
        (seed 42) ``jfleg`` validation split.
    :return: the list that was passed to ``InferRequestWrapper``.
    """
    calibration_data = []
    ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past

    # Wrap decoder inference for data collection. Only one wrapping mechanism is
    # installed and the pipeline is run once, so each sample is gathered once.
    original_infer_request = ov_decoder.request
    ov_decoder.request = InferRequestWrapper(original_infer_request, calibration_data)
    try:
        # Run inference for data collection
        calibration_dataset = datasets.load_dataset("jfleg", split="validation").shuffle(seed=42)[:calibration_dataset_size]
        sentences = calibration_dataset["sentence"]
        # Use the real sentence count for the progress bar, since the split may
        # hold fewer items than requested.
        for data_item in tqdm(sentences, total=len(sentences), desc="Collecting calibration data"):
            grammar_corrector_pipe_fp32(data_item)
    finally:
        # Always undo the patch so the pipeline is not left instrumented.
        ov_decoder.request = original_infer_request

    return calibration_data


Expand Down

0 comments on commit 19f10e5

Please sign in to comment.