
Commit b393399

Transition to stateful seq2seq models (#2667)
1 parent b22d7e6 commit b393399
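
This commit moves the notebooks from the legacy two-decoder export (a first-step decoder plus a separate decoder_with_past that passes the KV cache as explicit inputs/outputs) to the stateful export, where optimum-intel produces a single decoder whose KV cache is kept as state inside the compiled model. A minimal sketch of what that means for calling code, assuming a recent optimum-intel with stateful export enabled; the getattr fallback is illustrative and not part of this commit:

from optimum.intel.openvino import OVModelForSpeechSeq2Seq

ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_path, compile=False)
# Stateful layout: a single decoder submodel replaces decoder + decoder_with_past.
decoder = getattr(ov_model, "decoder_with_past", None) or ov_model.decoder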

File tree

3 files changed: +14 -16 lines changed

notebooks/distil-whisper-asr/distil-whisper-asr.ipynb (+7 -9)
@@ -984,7 +984,7 @@
 "    encoder_calibration_data = []\n",
 "    decoder_calibration_data = []\n",
 "    ov_model.encoder.request = InferRequestWrapper(ov_model.encoder.request, encoder_calibration_data, apply_caching=True)\n",
-"    ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request,\n",
+"    ov_model.decoder.request = InferRequestWrapper(ov_model.decoder.request,\n",
 "                                                   decoder_calibration_data,\n",
 "                                                   apply_caching=True)\n",
 "\n",
@@ -996,7 +996,7 @@
 "        ov_model.generate(input_features)\n",
 "    finally:\n",
 "        ov_model.encoder.request = ov_model.encoder.request.request\n",
-"        ov_model.decoder_with_past.request = ov_model.decoder_with_past.request.request\n",
+"        ov_model.decoder.request = ov_model.decoder.request.request\n",
 "\n",
 "    return encoder_calibration_data, decoder_calibration_data"
 ]
@@ -1146,23 +1146,21 @@
 "    gc.collect()\n",
 "\n",
 "    print(\"Quantizing decoder with past\")\n",
-"    quantized_decoder_with_past = nncf.quantize(\n",
-"        ov_model.decoder_with_past.model,\n",
+"    quantized_decoder = nncf.quantize(\n",
+"        ov_model.decoder.model,\n",
 "        nncf.Dataset(decoder_calibration_data),\n",
 "        subset_size=len(decoder_calibration_data),\n",
 "        model_type=nncf.ModelType.TRANSFORMER,\n",
 "        # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search\n",
 "        advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95)\n",
 "    )\n",
-"    ov.save_model(quantized_decoder_with_past, quantized_model_path / \"openvino_decoder_with_past_model.xml\")\n",
-"    del quantized_decoder_with_past\n",
+"    ov.save_model(quantized_decoder, quantized_model_path / \"openvino_decoder_model.xml\")\n",
+"    del quantized_decoder\n",
 "    del decoder_calibration_data\n",
 "    gc.collect()\n",
 "\n",
 "    # Copy the config file and the first-step-decoder manually\n",
 "    shutil.copy(model_path / \"config.json\", quantized_model_path / \"config.json\")\n",
-"    shutil.copy(model_path / \"openvino_decoder_model.xml\", quantized_model_path / \"openvino_decoder_model.xml\")\n",
-"    shutil.copy(model_path / \"openvino_decoder_model.bin\", quantized_model_path / \"openvino_decoder_model.bin\")\n",
 "\n",
 "    quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_model_path, ov_config=ov_config, compile=False)\n",
 "    quantized_ov_model.to(device.value)\n",
@@ -1392,7 +1390,7 @@
 "    whole_infer_times = []\n",
 "    time_fn(ov_model, \"generate\", whole_infer_times)\n",
 "    time_fn(ov_model.encoder, \"forward\", encoder_infer_times)\n",
-"    time_fn(ov_model.decoder_with_past, \"forward\", decoder_with_past_infer_times)\n",
+"    time_fn(ov_model.decoder, \"forward\", decoder_with_past_infer_times)\n",
 "\n",
 "    ground_truths = []\n",
 "    predictions = []\n",

notebooks/grammar-correction/grammar-correction.ipynb (+2 -2)
@@ -954,7 +954,7 @@
 "grammar_corrector_pipe_fp32 = grammar_corrector_pipe\n",
 "grammar_corrector_pipe_int8 = None\n",
 "if to_quantize.value:\n",
-"    quantized_model_path = Path(\"quantized_decoder_with_past\") / \"openvino_model.xml\"\n",
+"    quantized_model_path = Path(\"quantized_decoder\") / \"openvino_model.xml\"\n",
 "    grammar_corrector_pipe_int8 = get_quantized_pipeline(\n",
 "        grammar_corrector_pipe_fp32,\n",
 "        grammar_corrector_tokenizer,\n",
@@ -1063,7 +1063,7 @@
 "\n",
 "if to_quantize.value:\n",
 "    model_size_fp32, model_size_int8 = calculate_compression_rate(\n",
-"        grammar_corrector_dir / \"openvino_decoder_with_past_model.xml\",\n",
+"        grammar_corrector_dir / \"openvino_decoder_model.xml\",\n",
 "        quantized_model_path,\n",
 "    )"
 ]
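
One consequence of the new layout worth noting: with stateful export the decoder IR on disk is openvino_decoder_model.xml, and no openvino_decoder_with_past_model.xml is produced, which is why the size comparison above switches file names. A hypothetical sanity check (the directory name is illustrative):

from pathlib import Path

grammar_corrector_dir = Path("grammar_corrector")  # assumed export directory
assert (grammar_corrector_dir / "openvino_decoder_model.xml").exists()
assert not (grammar_corrector_dir / "openvino_decoder_with_past_model.xml").exists()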

notebooks/grammar-correction/utils.py (+5 -5)
@@ -23,7 +23,7 @@

 def collect_calibration_data(grammar_corrector_pipe_fp32: Pipeline, calibration_dataset_size: int) -> List[Dict]:
     calibration_data = []
-    ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past
+    ov_decoder = grammar_corrector_pipe_fp32.model.decoder

     # Wrap decoder inference for data collection
     ov_decoder.request = InferRequestWrapper(ov_decoder.request, calibration_data, apply_caching=True)
@@ -55,7 +55,7 @@ def quantize(
        quantized_model = core.read_model(model=quantized_model_path)
    else:
        calibration_data = collect_calibration_data(grammar_corrector_pipe_fp32, calibration_dataset_size)
-       ov_decoder = grammar_corrector_pipe_fp32.model.decoder_with_past
+       ov_decoder = grammar_corrector_pipe_fp32.model.decoder
        quantized_model = nncf.quantize(
            ov_decoder.model,
            calibration_dataset=nncf.Dataset(calibration_data),
@@ -93,9 +93,9 @@ def get_quantized_pipeline(

    # Load quantized model into grammar correction pipeline
    grammar_corrector_model_int8 = OVModelForSeq2SeqLM.from_pretrained(grammar_corrector_dir, device=device)
-   grammar_corrector_model_int8.decoder_with_past.model = quantized_model
-   grammar_corrector_model_int8.decoder_with_past.request = None
-   grammar_corrector_model_int8.decoder_with_past._compile()
+   grammar_corrector_model_int8.decoder.model = quantized_model
+   grammar_corrector_model_int8.decoder.request = None
+   grammar_corrector_model_int8.decoder._compile()
    grammar_corrector_pipe_int8 = pipeline(
        "text2text-generation", model=grammar_corrector_model_int8, tokenizer=grammar_corrector_tokenizer, device=torch.device("cpu")
    )
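
The swap in get_quantized_pipeline is the core of the change: the quantized IR replaces the stateful decoder's model in place, the stale compiled request is dropped, and the submodel is recompiled. A condensed, hypothetical usage sketch of that pattern (the export directory and device are illustrative):

from optimum.intel.openvino import OVModelForSeq2SeqLM
from transformers import pipeline
import torch

model_int8 = OVModelForSeq2SeqLM.from_pretrained("grammar_corrector", device="CPU")  # assumed export dir
model_int8.decoder.model = quantized_model  # the INT8 ov.Model returned by nncf.quantize
model_int8.decoder.request = None           # invalidate the old compiled request
model_int8.decoder._compile()               # recompile the swapped-in model
pipe_int8 = pipeline("text2text-generation", model=model_int8, tokenizer=grammar_corrector_tokenizer, device=torch.device("cpu"))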
