|
| 1 | +// Copyright (C) 2018-2024 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
#include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"

#include <cstdint>
#include <string>
#include <vector>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/pass/manager.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"
| 11 | + |
| 12 | +TEST_F(TransformationTestsF, MarkRopeInputsToKeepInMixedPrecisionTest) { |
| 13 | + /* |
| 14 | + The 2nd/3rd inputs of ROPE is marked as FP32 |
| 15 | + Param2 Param3 |
| 16 | + \ / |
| 17 | + \ / |
| 18 | + Matmul(FP32) |
| 19 | + | |
| 20 | + Transpose(FP32) |
| 21 | + | |
| 22 | + Concat(FP32) |
| 23 | + / \ |
| 24 | + / \ |
| 25 | + Param1 Cos(FP32) Sin(FP32) |
| 26 | + \ | / |
| 27 | + \ | / |
| 28 | + \ | / |
| 29 | + ROPE |
| 30 | + */ |
| 31 | + { |
| 32 | + auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 10, 8, 64}); |
| 33 | + auto input_a = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 32, 1}); |
| 34 | + auto input_b = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, 10}); |
| 35 | + auto matmul = std::make_shared<ov::opset1::MatMul>(input_a, input_b); |
| 36 | + auto transpose_order = |
| 37 | + ov::op::v0::Constant::create(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{0, 2, 1}); |
| 38 | + auto transpose = std::make_shared<ov::opset1::Transpose>(matmul, transpose_order); |
| 39 | + auto concat = std::make_shared<ov::opset1::Concat>(ov::NodeVector{transpose, transpose}, -1); |
| 40 | + auto cos = std::make_shared<ov::opset1::Cos>(concat); |
| 41 | + auto sin = std::make_shared<ov::opset1::Sin>(concat); |
| 42 | + ov::op::internal::RoPE::Config config; |
| 43 | + auto rope = |
| 44 | + std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{input->output(0), cos->output(0), sin->output(0)}, |
| 45 | + config); |
| 46 | + model = std::make_shared<ov::Model>(rope, ov::ParameterVector{input, input_a, input_b}, "model"); |
| 47 | + } |
| 48 | + |
| 49 | + manager.register_pass<ov::pass::MarkRopeInputsToKeepInMixedPrecision>(); |
| 50 | + |
| 51 | + { |
| 52 | + auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 10, 8, 64}); |
| 53 | + auto input_a = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 32, 1}); |
| 54 | + auto input_b = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 1, 10}); |
| 55 | + auto matmul = std::make_shared<ov::opset1::MatMul>(input_a, input_b); |
| 56 | + auto transpose_order = |
| 57 | + ov::op::v0::Constant::create(ov::element::i32, ov::Shape{3}, std::vector<int32_t>{0, 2, 1}); |
| 58 | + auto transpose = std::make_shared<ov::opset1::Transpose>(matmul, transpose_order); |
| 59 | + auto concat = std::make_shared<ov::opset1::Concat>(ov::NodeVector{transpose, transpose}, -1); |
| 60 | + auto cos = std::make_shared<ov::opset1::Cos>(concat); |
| 61 | + auto sin = std::make_shared<ov::opset1::Sin>(concat); |
| 62 | + disable_fp16_compression(matmul); |
| 63 | + disable_fp16_compression(transpose); |
| 64 | + disable_fp16_compression(concat); |
| 65 | + disable_fp16_compression(cos); |
| 66 | + disable_fp16_compression(sin); |
| 67 | + ov::op::internal::RoPE::Config config; |
| 68 | + auto rope = |
| 69 | + std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{input->output(0), cos->output(0), sin->output(0)}, |
| 70 | + config); |
| 71 | + model_ref = std::make_shared<ov::Model>(rope, ov::ParameterVector{input, input_a, input_b}, "model_ref"); |
| 72 | + } |
| 73 | + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); |
| 74 | + comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); |
| 75 | +} |
0 commit comments