Commit 3e0b5bb

eaidova and andreyanufr authored
small refactoring (#2411)
Co-authored-by: Andrei Anufriev <[email protected]>
1 parent 45f3b9a commit 3e0b5bb

File tree

3 files changed: +171 -39 lines changed


notebooks/mllama-3.2/data_preprocessing.py

Lines changed: 0 additions & 29 deletions
@@ -39,35 +39,6 @@ def get_pil_from_url(url):
     return image.convert("RGB")
 
 
-# def collate_fn_llm(example, image_column="image_url", text_column="caption"):
-#     """
-#     Preprocesses an example by loading and transforming image and text data.
-#     Checks if the text data in the example is valid by calling the `check_text_data` function.
-#     Downloads the image specified by the URL in the image_column by calling the `get_pil_from_url` function.
-#     If there is any error during the download process, returns None.
-#     Returns the preprocessed inputs with transformed image and text data.
-#     """
-#     assert len(example) == 1
-#     example = example[0]
-
-#     if not check_text_data(example[text_column]):
-#         raise ValueError("Text data is not valid")
-
-#     url = example[image_column]
-#     try:
-#         image = get_pil_from_url(url)
-#         h, w = image.size
-#         if h == 1 or w == 1:
-#             return None
-#     except Exception:
-#         return None
-
-#     inputs = processor(text="<|image|><|begin_of_text|>" + example[text_column], images=image, return_tensors="pt", padding=True)
-#     if inputs['input_ids'].shape[1] > max_length:
-#         return None
-#     return inputs
-
-
 def prepare_calibration_data_vision(dataloader, init_steps):
     """
     This function prepares calibration data from a dataloader for a specified number of initialization steps.
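
The deleted collate_fn_llm was already commented out, so removing it leaves prepare_calibration_data_vision as the calibration entry point in this file. For orientation, here is a minimal, hypothetical sketch of driving it with a dataloader; the toy in-memory dataset, batch size, and pass-through collate function are illustrative assumptions, not part of this commit.

    # Hypothetical usage sketch (not part of the commit).
    import torch
    from data_preprocessing import prepare_calibration_data_vision

    # Assumption: a toy in-memory dataset standing in for the notebook's real one.
    samples = [{"image_url": "https://example.com/cat.jpg", "caption": "a cat"}]

    # A list works as a map-style dataset; the identity collate hands examples
    # through unchanged, since preprocessing happens downstream.
    dataloader = torch.utils.data.DataLoader(samples, batch_size=1, collate_fn=lambda batch: batch)
    calibration_data = prepare_calibration_data_vision(dataloader, init_steps=1)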

notebooks/mllama-3.2/ov_mllama_compression.py

Lines changed: 0 additions & 6 deletions
@@ -133,9 +133,3 @@ def compress(
     print(f"Model compression finished. Compressed model can be found in {saving_path}")
 
     return saving_path
-
-
-# model_id = "Llama-3.2-11B-Vision-Instruct/OV"
-# processor = AutoProcessor.from_pretrained(model_id)
-
-# compress(model_id, processor)
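
The removed comments were an inlined usage example for compress(). Reconstructed below as a runnable sketch, using the model path from this commit's __main__ block; treat the path as a placeholder and the AutoProcessor import as an assumption carried over from the deleted comment.

    # Usage sketch reconstructed from the deleted comments (illustrative).
    from transformers import AutoProcessor
    from ov_mllama_compression import compress

    model_id = "Llama-3.2-11B-Vision-Instruct/OV"
    processor = AutoProcessor.from_pretrained(model_id)
    # compress() prints progress and returns the path of the compressed model.
    saving_path = compress(model_id, processor)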

notebooks/mllama-3.2/ov_mllama_helper.py

Lines changed: 171 additions & 4 deletions
@@ -4,7 +4,6 @@
 from transformers.models.llama.modeling_llama import repeat_kv
 from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
 from typing import Optional, Union, List, Tuple, Dict
-from optimum.exporters.openvino.stateful import patch_stateful
 from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import ModelOutput
 import openvino.runtime.opset13 as ops
@@ -83,6 +82,176 @@ def callback(matcher: Matcher) -> bool:
 }
 
 
+def model_has_state(ov_model: ov.Model):
+    return len(ov_model.get_sinks()) > 0
+
+
+def model_has_input_output_name(ov_model: ov.Model, name: str):
+    """
+    Helper function for checking that model has specified input or output name
+
+    Parameters:
+      ov_model (ov.Model):
+          openvino model
+      name (str):
+          name of input or output
+
+    Returns:
+      True if input or output with requested name exists else False
+    """
+    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])
+
+
+def fuse_cache_reorder(
+    ov_model: ov.Model,
+    not_kv_inputs: List[str],
+    key_value_input_names: List[str],
+    gather_dim: int,
+):
+    """
+    Fuses _reorder_cache during the generate cycle into ov.Model. Used with stateful models, because we cannot modify model state directly.
+
+    Adds a new beam_idx parameter and a Gather op per each kv-cache input in a given model.
+    Should be run before make_stateful. Implements optimum's _reorder_cache
+    inside the model at the beginning of each iteration.
+    Gather works along the given gather_dim dimension that may vary from model to model.
+    KV-cache inputs are identified based on names in key_value_input_names.
+    Appends the new beam_idx parameter to not_kv_inputs.
+
+    Parameters:
+      ov_model (`ov.Model`):
+          openvino model for processing
+      not_kv_inputs (`List[str]`):
+          list of input nodes in the model that are not related to past key values
+      key_value_input_names (`List[str]`):
+          list of names for key value input layers
+      gather_dim (int):
+          dimension for gathering cache during reorder pass
+    """
+
+    if model_has_input_output_name(ov_model, "beam_idx"):
+        raise ValueError("Model already has fused cache")
+    input_batch = ov_model.input("input_ids").get_partial_shape()[0]
+    beam_idx = ops.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
+    beam_idx.output(0).get_tensor().add_names({"beam_idx"})  # why list is not accepted?
+    ov_model.add_parameters([beam_idx])
+    not_kv_inputs.append(ov_model.inputs[-1])
+    # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
+    for input_name in key_value_input_names:
+        parameter_output_port = ov_model.input(input_name)
+        consumers = parameter_output_port.get_target_inputs()
+        gather = ops.gather(parameter_output_port, beam_idx, ops.constant(gather_dim))
+        for consumer in consumers:
+            consumer.replace_source_output(gather.output(0))
+    ov_model.validate_nodes_and_infer_types()
+
+
+def build_state_initializer(ov_model: ov.Model, batch_dim: int):
+    """
+    Build initialization ShapeOf Expression for all ReadValue ops
+
+    Parameters:
+      ov_model (ov.Model):
+          openvino model
+      batch_dim (int):
+          index of dimension corresponding to batch size
+    """
+    input_ids = ov_model.input("input_ids")
+    batch = ops.gather(
+        ops.shape_of(input_ids, output_type="i64"),
+        ops.constant([0]),
+        ops.constant(0),
+    )
+    for op in ov_model.get_ops():
+        if op.get_type_name() == "ReadValue":
+            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
+            dims[batch_dim] = batch
+            dims = [(ops.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims]
+            shape = ops.concat(dims, axis=0)
+            broadcast = ops.broadcast(ops.constant(0.0, dtype=op.get_output_element_type(0)), shape)
+            op.set_arguments([broadcast])
+    ov_model.validate_nodes_and_infer_types()
+
+
+def make_stateful(
+    ov_model: ov.Model,
+    not_kv_inputs: List[str],
+    key_value_input_names: List[str],
+    key_value_output_names: List[str],
+    batch_dim: int,
+    num_attention_heads: int,
+    num_beams_and_batch: int = None,
+):
+    """
+    Hides kv-cache inputs and outputs inside the model as variables.
+
+    Parameters:
+      ov_model (ov.Model):
+          openvino model
+      not_kv_inputs (`List[str]`):
+          list of input nodes in the model that are not related to past key values
+      key_value_input_names (`List[str]`):
+          list of names for key value input layers
+      key_value_output_names (`List[str]`):
+          list of names for key value output layers
+      batch_dim (int):
+          index of batch dimension in key value layers
+      num_attention_heads (int):
+          number of attention heads for batch dimension initialization
+      num_beams_and_batch (int):
+          precalculated number of beams and batch for shapes initialization
+    """
+    from openvino._offline_transformations import apply_make_stateful_transformation
+
+    input_output_map = {}
+
+    if num_beams_and_batch is not None:
+        # Set batch size for input_ids and attention mask to avoid a dynamic dimension being propagated from the end of the model back to ReadValue
+        for input in not_kv_inputs:
+            shape = input.get_partial_shape()
+            if shape.rank.get_length() <= 2:  # == 1 for beam_index
+                shape[0] = num_beams_and_batch
+                input.get_node().set_partial_shape(shape)
+    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
+        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
+        if num_beams_and_batch is not None:
+            input = ov_model.input(kv_name_pair[0])
+            shape = input.get_partial_shape()
+            shape[batch_dim] = num_beams_and_batch * num_attention_heads
+            input.get_node().set_partial_shape(shape)
+
+    if num_beams_and_batch is not None:
+        # Re-validate the model if shapes were altered above
+        ov_model.validate_nodes_and_infer_types()
+
+    apply_make_stateful_transformation(ov_model, input_output_map)
+    if num_beams_and_batch is None:
+        build_state_initializer(ov_model, batch_dim)
+
+
+def patch_stateful(ov_model):
+    key_value_input_names = [key_name for key in ov_model.inputs for key_name in key.get_names() if "past_key_values" in key_name]
+    key_value_output_names = [key_name for key in ov_model.outputs for key_name in key.get_names() if "present" in key_name]
+    not_kv_inputs = [input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())]
+    if not key_value_input_names or not key_value_output_names:
+        return
+    batch_dim = 0
+    num_attention_heads = 1
+
+    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
+    make_stateful(
+        ov_model,
+        not_kv_inputs,
+        key_value_input_names,
+        key_value_output_names,
+        batch_dim,
+        num_attention_heads,
+        None,
+    )
+
+
 def convert_mllama(model_id, out_dir):
 
     out_dir = Path(out_dir)
@@ -306,8 +475,7 @@ def cross_attn_forward(
         output.get_tensor().set_names({output_name})
 
     ov_model.validate_nodes_and_infer_types()
-
-    patch_stateful(model.config.text_config, ov_model)
+    patch_stateful(ov_model)
     ov.save_model(ov_model, lang_model_path)
     del ov_model
     cleanup_torchscript_cache()
@@ -785,7 +953,6 @@ def prepare_remote_tensors(self):
 
 
 if __name__ == "__main__":
-    # convert_mllama("/home/ea/llama3.2/Llama-3.2-11B-Vision-Instruct", "Llama-3.2-11B-Vision-Instruct/OV")
     model_id = "Llama-3.2-11B-Vision-Instruct/OV"
     LANGUAGE_MODEL_NAME = "llm_int4_asym_r10_gs64_max_activation_variance_all_layers.xml"
     IMAGE_ENCODER_NAME = "openvino_vision_encoder.xml"
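
Taken together, this commit drops the optimum-intel patch_stateful import and ships a local implementation: fuse_cache_reorder adds a beam_idx input plus Gather ops for beam reordering, make_stateful converts the kv-cache inputs/outputs into internal variables, and build_state_initializer gives the resulting states a shape-aware zero initializer. Below is a minimal sketch of applying the new entry point to an already-converted language model; the .xml paths are assumptions for illustration, not file names this commit guarantees.

    # Illustrative sketch: make a converted OpenVINO language model stateful.
    import openvino as ov
    from ov_mllama_helper import model_has_state, patch_stateful

    core = ov.Core()
    ov_model = core.read_model("Llama-3.2-11B-Vision-Instruct/OV/llm.xml")  # assumed path

    if not model_has_state(ov_model):
        # Fuses beam reordering into the graph and hides kv-cache I/O as state.
        patch_stateful(ov_model)
        ov.save_model(ov_model, "Llama-3.2-11B-Vision-Instruct/OV/llm_stateful.xml")  # assumed path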
