Skip to content

Commit f0bc62c

Browse files
authored
Update SDPA to PagedAttention transformation to support phi3 sliding window (#29608)
### Details:
 - Update SDPA to PagedAttention transformation to support phi3 sliding window

### Tickets:
 - *CVS-163524*
1 parent 056e0ef commit f0bc62c

File tree

4 files changed

+535
-362
lines changed

4 files changed

+535
-362
lines changed

src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
2020
OPENVINO_MATCHER_PASS_RTTI("StateManagementPattern");
2121
StateManagementPattern(ParameterVector& kv_parameters,
2222
ParameterVector& model_remaining_params,
23-
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
2423
ParameterVector& parameters_to_remove,
2524
int& layer_index,
2625
ov::Output<Node> max_context_len,

src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
#include "openvino/op/concat.hpp"
1515
#include "openvino/op/divide.hpp"
1616
#include "openvino/op/gather.hpp"
17+
#include "openvino/op/greater_eq.hpp"
1718
#include "openvino/op/multiply.hpp"
1819
#include "openvino/op/paged_attention.hpp"
1920
#include "openvino/op/parameter.hpp"
21+
#include "openvino/op/range.hpp"
2022
#include "openvino/op/reshape.hpp"
2123
#include "openvino/op/scaled_dot_product_attention.hpp"
2224
#include "openvino/op/select.hpp"
@@ -28,6 +30,7 @@
2830
#include "openvino/op/transpose.hpp"
2931
#include "openvino/op/unsqueeze.hpp"
3032
#include "openvino/op/variadic_split.hpp"
33+
#include "openvino/pass/pattern/op/optional.hpp"
3134
#include "openvino/pass/pattern/op/or.hpp"
3235
#include "openvino/pass/pattern/op/wrap_type.hpp"
3336
#include "transformations/utils/utils.hpp"
@@ -173,6 +176,27 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi(
173176
return res_alibi_slopes;
174177
}
175178

179+
// Builds a matcher pattern for a mask-producing subgraph and returns two pattern nodes:
//   - first:  the root of the pattern (a v8::Slice), intended to be matched as the SDPA mask input;
//   - second: the v0::Constant pattern that is the second input of the leading v1::Add, so the
//             caller can look the matched constant up in the pattern map.
// Presumably this is the attention-mask subgraph emitted for Phi-3 models with a sliding
// window (per the function name) -- TODO confirm against the exported model.
// NOTE(review): the tNNN names look like node ids taken from a dumped graph -- confirm.
static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> handle_phi3_sliding_window() {
180+
using namespace ov::pass::pattern;
181+
182+
// Constant -> Add -> [Convert] -> Range -> Unsqueeze -> GreaterEqual -> Select -> Subtract
auto offset = wrap_type<v0::Constant>();
183+
auto t196 = wrap_type<v1::Add>({any_input(), offset});
184+
// The Convert is optional: the pattern matches with or without it.
auto t197 = pattern::optional<v0::Convert>(t196);
185+
auto t200 = pattern::wrap_type<v4::Range>({t197, any_input(), any_input()});
186+
auto t201 = pattern::wrap_type<v0::Unsqueeze>({t200, any_input()});
187+
auto t202 = pattern::wrap_type<v1::GreaterEqual>({any_input(), t201});
188+
auto t208 = pattern::wrap_type<v1::Select>({t202, any_input(), any_input()});
189+
auto t209 = pattern::wrap_type<v1::Subtract>({any_input(), t208});
190+
// -> [Convert] -> Select -> Unsqueeze -> Unsqueeze -> Broadcast -> Select -> Slice
auto t210 = pattern::optional<v0::Convert>(t209);
191+
auto t211 = pattern::wrap_type<v1::Select>({t210, any_input(), any_input()});
192+
auto t213 = pattern::wrap_type<v0::Unsqueeze>({t211, any_input()});
193+
auto t214 = pattern::wrap_type<v0::Unsqueeze>({t213, any_input()});
194+
auto t218 = pattern::wrap_type<v3::Broadcast>({t214, any_input()});
195+
auto t219 = pattern::wrap_type<v1::Select>({any_input(), any_input(), t218});
196+
// Root of the pattern: a Slice whose data input is the Select above.
auto mask = pattern::wrap_type<v8::Slice>({t219, any_input(), any_input(), any_input(), any_input()});
197+
// Return the pattern root together with the Constant pattern node.
return {mask, offset};
198+
}
199+
176200
// This function is copied verbatim from another file. It should probably be moved to a shared utility file.
177201
static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const std::string& name) {
178202
// Set name for both node and output tensor (should be only one tensor, and any other names will be overridden by a
@@ -207,7 +231,6 @@ static node_tuple kv_read_and_concat(ov::Output<ov::Node> kv_current) {
207231

208232
ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_parameters,
209233
ParameterVector& model_remaining_params,
210-
const std::shared_ptr<ov::op::v0::Constant>& sliding_window,
211234
ParameterVector& parameters_to_remove,
212235
int& layer_index,
213236
Output<Node> max_context_len,
@@ -297,15 +320,20 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
297320
std::shared_ptr<ov::Node> baichuan2_13b_alibi, baichuan2_13b_alibi_mask;
298321
std::tie(baichuan2_13b_alibi, baichuan2_13b_alibi_mask) = baichuan2_13b_alibi_pattern();
299322

323+
// Phi3-xxx-instruct case
324+
std::shared_ptr<ov::Node> phi3_mask, phi3_offset;
325+
std::tie(phi3_mask, phi3_offset) = handle_phi3_sliding_window();
326+
300327
auto q = pattern::any_input();
301328
auto scale_input = pattern::any_input();
302329

303330
auto k_to_sdpa =
304331
std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
305332
auto v_to_sdpa =
306333
std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
334+
307335
auto mask_to_sdpa = std::make_shared<pattern::op::Or>(
308-
OutputVector{general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, pattern::any_input()});
336+
OutputVector{phi3_mask, general_alibi_mask, jais_alibi_mask, baichuan2_13b_alibi_mask, pattern::any_input()});
309337

310338
auto sdpa_with_4_inputs =
311339
pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
@@ -317,7 +345,6 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
317345
ov::matcher_pass_callback callback = [=,
318346
&kv_parameters,
319347
&model_remaining_params,
320-
&sliding_window,
321348
&parameters_to_remove,
322349
&block_indices_inputs_for_each_layer,
323350
&score_results,
@@ -492,6 +519,18 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
492519

493520
OutputVector pa_arguments = {q_reshape, k_reshape, v_reshape, k_parameter, v_parameter};
494521
pa_arguments.insert(pa_arguments.end(), model_remaining_params.begin(), model_remaining_params.end());
522+
523+
std::shared_ptr<Node> sliding_window;
524+
if (pattern_map.count(phi3_offset)) {
525+
auto offset = pattern_map.at(phi3_offset).get_node_shared_ptr();
526+
if (offset->get_element_type() != element::i32) {
527+
offset = std::make_shared<v0::Convert>(offset, element::i32);
528+
}
529+
sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
530+
} else {
531+
sliding_window = v0::Constant::create(element::i32, Shape{}, {0});
532+
}
533+
495534
std::initializer_list<std::shared_ptr<Node>> additional_params = {scale,
496535
sliding_window,
497536
alibi_slopes,

0 commit comments

Comments
 (0)