Skip to content

Commit eda2f7f

Browse files
authored
[Core][CPU]markup rope's sin/cos generation with f32 (openvinotoolkit#25662)
### Details: - *Sin/Cos table generation must run in f32; otherwise it causes accuracy issues.* - Reference: huggingface/transformers#29285 ### Tickets: - *CVS-146672*
1 parent 45b4737 commit eda2f7f

File tree

6 files changed

+185
-8
lines changed

6 files changed

+185
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <unordered_set>  // required for the `visited` member below

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/**
 * @ingroup ov_transformation_common_api
 * @brief This transformation marks the 2nd/3rd inputs of RoPE with FP32 to maintain accuracy.
 * +-------+     +-------+    +-------+
 * |input1 |     |input2 |    |input3 |
 * |(orig) |     |(fp32) |    |(fp32) |
 * +---|---+     +---|---+    +---|---+
 *     |             |            |
 *     |             |            |
 * +---+-------------+------------+--+
 * |                                 |
 * |              ROPE               |
 * +---------------------------------+
 */
class TRANSFORMATIONS_API MarkRopeInputsToKeepInMixedPrecision : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("MarkRopeInputsToKeepInMixedPrecision", "0");
    MarkRopeInputsToKeepInMixedPrecision();

private:
    // Nodes already marked by a previous match; lets the pass skip
    // sin/cos generation subgraphs shared between several RoPE nodes.
    std::unordered_set<ov::Node*> visited;
};

}  // namespace pass
}  // namespace ov

src/common/transformations/include/transformations/utils/utils.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,19 @@ TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<N
215215
TRANSFORMATIONS_API bool shapes_equal_except_dynamic_expected_batch(const PartialShape& expected,
216216
const PartialShape& actual);
217217

218+
/**
 * \brief Traverses path starting from `node`, and calls "func" for each ov::Node.
 *
 * \param node                The node from which path is started.
 * \param visited             Set of nodes which were visited.
 * \param func                The function which is called for each visited node.
 * \param skip_node_predicate Predicate to skip nodes.
 */
TRANSFORMATIONS_API void visit_path(ov::Node* node,
                                    std::unordered_set<ov::Node*>& visited,
                                    std::function<void(ov::Node*)> func,
                                    std::function<bool(ov::Node*)> skip_node_predicate);
230+
218231
/**
219232
* \brief Traverses a shapeOf subgraph starting from the node and not including the ShapeOf nodes,
220233
* and calls "func" for each ov::Node.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"

#include <unordered_set>

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"
#include "transformations/utils/gen_pattern.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::MarkRopeInputsToKeepInMixedPrecision::MarkRopeInputsToKeepInMixedPrecision() {
    MATCHER_SCOPE(MarkRopeInputsToKeepInMixedPrecision);
    using namespace ov::pass::pattern;
    using namespace ov::gen_pattern;

    // Match any RoPE node; the 2nd/3rd inputs feed the cos/sin tables.
    auto cos_table = any_input();
    auto sin_table = any_input();
    auto rope_pattern = makePattern<ov::op::internal::RoPE>({any_input(), cos_table, sin_table});

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        // Mark every node on the table-generation path with disable_fp16_compression.
        auto mark_fp32 = [](ov::Node* node) {
            ov::disable_fp16_compression(node->shared_from_this());
        };
        // Constant and Parameter nodes are skipped by the traversal.
        auto is_skipped = [](ov::Node* node) -> bool {
            return ov::is_type<ov::op::v0::Constant>(node) || ov::is_type<ov::op::v0::Parameter>(node);
        };
        for (auto* table_input : {pattern_map.at(cos_table).get_node(), pattern_map.at(sin_table).get_node()}) {
            if (visited.count(table_input) == 0) {
                ov::op::util::visit_path(table_input, visited, mark_fp32, is_skipped);
            }
        }
        return false;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(rope_pattern, matcher_name);
    this->register_matcher(m, callback);
}

src/common/transformations/src/transformations/utils/utils.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@ namespace ov {
3030
namespace op {
3131
namespace util {
3232

33-
namespace {
34-
void visit_path_impl(ov::Node* node,
35-
std::unordered_set<ov::Node*>& visited,
36-
std::function<void(ov::Node*)> func,
37-
std::function<bool(ov::Node*)> skip_node_predicate) {
33+
void visit_path(ov::Node* node,
34+
std::unordered_set<ov::Node*>& visited,
35+
std::function<void(ov::Node*)> func,
36+
std::function<bool(ov::Node*)> skip_node_predicate) {
3837
if (!node)
3938
return;
4039
visited.insert(node);
@@ -56,7 +55,6 @@ void visit_path_impl(ov::Node* node,
5655
}
5756
}
5857
}
59-
} // namespace
6058

6159
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
6260
switch (const_node->get_element_type()) {
@@ -287,7 +285,7 @@ void visit_shape_path(Node* node, std::unordered_set<ov::Node*>& visited, std::f
287285
auto is_shapeof = [](ov::Node* node) {
288286
return ov::is_type<ov::op::v0::ShapeOf>(node) || ov::is_type<ov::op::v3::ShapeOf>(node);
289287
};
290-
visit_path_impl(node, visited, func, is_shapeof);
288+
visit_path(node, visited, func, is_shapeof);
291289
}
292290

293291
void visit_constant_path(ov::Node* node, std::unordered_set<ov::Node*>& visited, std::function<void(ov::Node*)> func) {
@@ -296,7 +294,7 @@ void visit_constant_path(ov::Node* node, std::unordered_set<ov::Node*>& visited,
296294
"visit_constant_path is called for non-constant path.");
297295
return false;
298296
};
299-
visit_path_impl(node, visited, func, check_parameter);
297+
visit_path(node, visited, func, check_parameter);
300298
}
301299

302300
bool is_dequantization_subgraph(const Output<Node>& node) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/pass/manager.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"

TEST_F(TransformationTestsF, MarkRopeInputsToKeepInMixedPrecisionTest) {
    /*
      The 2nd/3rd inputs of ROPE are marked as FP32:
          Param2   Param3
             \       /
              \     /
            Matmul(FP32)
                |
           Transpose(FP32)
                |
            Concat(FP32)
              /    \
             /      \
    Param1 Cos(FP32) Sin(FP32)
         \     |    /
          \    |   /
           \   |  /
             ROPE
    */
    // Builds the test graph. When `expect_fp32_marks` is true, the sin/cos
    // generation path is pre-marked with disable_fp16_compression — this is
    // the state the transformation is expected to produce.
    auto build_model = [](bool expect_fp32_marks, const std::string& name) {
        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 10, 8, 64});
        auto input_a = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 32, 1});
        auto input_b = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, 10});
        auto matmul = std::make_shared<ov::opset1::MatMul>(input_a, input_b);
        auto transpose_order =
            ov::op::v0::Constant::create(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{0, 2, 1});
        auto transpose = std::make_shared<ov::opset1::Transpose>(matmul, transpose_order);
        auto concat = std::make_shared<ov::opset1::Concat>(ov::NodeVector{transpose, transpose}, -1);
        auto cos = std::make_shared<ov::opset1::Cos>(concat);
        auto sin = std::make_shared<ov::opset1::Sin>(concat);
        if (expect_fp32_marks) {
            disable_fp16_compression(matmul);
            disable_fp16_compression(transpose);
            disable_fp16_compression(concat);
            disable_fp16_compression(cos);
            disable_fp16_compression(sin);
        }
        ov::op::internal::RoPE::Config config;
        auto rope =
            std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{input->output(0), cos->output(0), sin->output(0)},
                                                     config);
        return std::make_shared<ov::Model>(rope, ov::ParameterVector{input, input_a, input_b}, name);
    };

    model = build_model(false, "model");
    manager.register_pass<ov::pass::MarkRopeInputsToKeepInMixedPrecision>();
    model_ref = build_model(true, "model_ref");

    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
3636
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"
3737
#include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp"
38+
#include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"
3839
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
3940
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
4041
#include "transformations/op_conversions/convert_avgpool_downgrade.hpp"
@@ -853,6 +854,9 @@ void Transformations::PostLpt() {
853854
}
854855
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::transpose_sinking::TSShapeOfForward);
855856
CPU_REGISTER_PASS_COMMON(postLPTPassManager, StatefulSDPAFusion);
857+
// Mark up RoPE's sin/cos inputs to keep them in FP32 when running BF16/F16 inference.
858+
if (one_of(inferencePrecision, ov::element::bf16, ov::element::f16))
859+
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::MarkRopeInputsToKeepInMixedPrecision);
856860

857861
// Should be before Snippets pipeline because Ngram pattern contains eltwise nodes that can be tokenized by Snippets.
858862
auto symbolic_pipeline = CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::SymbolicOptimizations, false);

0 commit comments

Comments
 (0)