14
14
#include " openvino/op/concat.hpp"
15
15
#include " openvino/op/divide.hpp"
16
16
#include " openvino/op/gather.hpp"
17
+ #include " openvino/op/greater_eq.hpp"
17
18
#include " openvino/op/multiply.hpp"
18
19
#include " openvino/op/paged_attention.hpp"
19
20
#include " openvino/op/parameter.hpp"
21
+ #include " openvino/op/range.hpp"
20
22
#include " openvino/op/reshape.hpp"
21
23
#include " openvino/op/scaled_dot_product_attention.hpp"
22
24
#include " openvino/op/select.hpp"
28
30
#include " openvino/op/transpose.hpp"
29
31
#include " openvino/op/unsqueeze.hpp"
30
32
#include " openvino/op/variadic_split.hpp"
33
+ #include " openvino/pass/pattern/op/optional.hpp"
31
34
#include " openvino/pass/pattern/op/or.hpp"
32
35
#include " openvino/pass/pattern/op/wrap_type.hpp"
33
36
#include " transformations/utils/utils.hpp"
@@ -173,6 +176,27 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi(
173
176
return res_alibi_slopes;
174
177
}
175
178
179
+ static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> handle_phi3_sliding_window () {
180
+ using namespace ov ::pass::pattern;
181
+
182
+ auto offset = wrap_type<v0::Constant>();
183
+ auto t196 = wrap_type<v1::Add>({any_input (), offset});
184
+ auto t197 = pattern::optional<v0::Convert>(t196);
185
+ auto t200 = pattern::wrap_type<v4::Range>({t197, any_input (), any_input ()});
186
+ auto t201 = pattern::wrap_type<v0::Unsqueeze>({t200, any_input ()});
187
+ auto t202 = pattern::wrap_type<v1::GreaterEqual>({any_input (), t201});
188
+ auto t208 = pattern::wrap_type<v1::Select>({t202, any_input (), any_input ()});
189
+ auto t209 = pattern::wrap_type<v1::Subtract>({any_input (), t208});
190
+ auto t210 = pattern::optional<v0::Convert>(t209);
191
+ auto t211 = pattern::wrap_type<v1::Select>({t210, any_input (), any_input ()});
192
+ auto t213 = pattern::wrap_type<v0::Unsqueeze>({t211, any_input ()});
193
+ auto t214 = pattern::wrap_type<v0::Unsqueeze>({t213, any_input ()});
194
+ auto t218 = pattern::wrap_type<v3::Broadcast>({t214, any_input ()});
195
+ auto t219 = pattern::wrap_type<v1::Select>({any_input (), any_input (), t218});
196
+ auto mask = pattern::wrap_type<v8::Slice>({t219, any_input (), any_input (), any_input (), any_input ()});
197
+ return {mask, offset};
198
+ }
199
+
176
200
// This function was copied verbatim from another file; it should probably be moved to a shared utility file.
177
201
static std::shared_ptr<v0::Parameter> setName (std::shared_ptr<v0::Parameter> node, const std::string& name) {
178
202
// Set name for both node and output tensor (should be only one tensor, and any other names will be overridden by a
@@ -207,7 +231,6 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {
207
231
208
232
ov::pass::StateManagementPattern::StateManagementPattern (ParameterVector& kv_parameters,
209
233
ParameterVector& model_remaining_params,
210
- const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
211
234
ParameterVector& parameters_to_remove,
212
235
int & layer_index,
213
236
Output<Node> max_context_len,
@@ -297,15 +320,20 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
297
320
std::shared_ptr<ov::Node> baichuan2_13b_alibi, baichuan2_13b_alibi_mask;
298
321
std::tie (baichuan2_13b_alibi, baichuan2_13b_alibi_mask) = baichuan2_13b_alibi_pattern ();
299
322
323
+ // Phi3-xxx-instruct case
324
+ std::shared_ptr<ov::Node> phi3_mask, phi3_offset;
325
+ std::tie (phi3_mask, phi3_offset) = handle_phi3_sliding_window ();
326
+
300
327
auto q = pattern::any_input ();
301
328
auto scale_input = pattern::any_input ();
302
329
303
330
auto k_to_sdpa =
304
331
std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
305
332
auto v_to_sdpa =
306
333
std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
334
+
307
335
auto mask_to_sdpa = std::make_shared<pattern::op::Or>(
308
- OutputVector{general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, pattern::any_input ()});
336
+ OutputVector{phi3_mask, general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, pattern::any_input ()});
309
337
310
338
auto sdpa_with_4_inputs =
311
339
pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
@@ -317,7 +345,6 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
317
345
ov::matcher_pass_callback callback = [=,
318
346
&kv_parameters,
319
347
&model_remaining_params,
320
- &sliding_window,
321
348
¶meters_to_remove,
322
349
&block_indices_inputs_for_each_layer,
323
350
&score_results,
@@ -492,6 +519,18 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
492
519
493
520
OutputVector pa_arguments = {q_reshape, k_reshape, v_reshape, k_parameter, v_parameter};
494
521
pa_arguments.insert (pa_arguments.end (), model_remaining_params.begin (), model_remaining_params.end ());
522
+
523
+ std::shared_ptr<Node> sliding_window;
524
+ if (pattern_map.count (phi3_offset)) {
525
+ auto offset = pattern_map.at (phi3_offset).get_node_shared_ptr ();
526
+ if (offset->get_element_type () != element::i32) {
527
+ offset = std::make_shared<v0::Convert>(offset, element::i32);
528
+ }
529
+ sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create (element::i32, Shape{}, {2 }), offset);
530
+ } else {
531
+ sliding_window = v0::Constant::create (element::i32, Shape{}, {0 });
532
+ }
533
+
495
534
std::initializer_list<std::shared_ptr<Node>> additional_params = {scale,
496
535
sliding_window,
497
536
alibi_slopes,
0 commit comments