Commit 93ee7bf

[QNN EP] MatMul+Add->Gemm fusion when AttentionFusion isn't enabled (#25017)
### Description

MatMul+Add->Gemm fusion when AttentionFusion isn't enabled.

### Motivation and Context

The graph transformation [MatMulAddFusion](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/optimizer/matmul_add_fusion.cc) folds an `ONNX::MatMul` followed by an `ONNX::Add` into an `ONNX::Gemm`; however, it [intentionally skips the portion that belongs to the "Attention Pattern"](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/optimizer/matmul_add_fusion.cc#L21). This results in poor performance on the QNN EP (and on other EPs that do not run the *AttentionFusion transformers) due to unfused MatMul + Add pairs.

![image](https://github.com/user-attachments/assets/cad0b2c6-ab07-4ced-a647-396c04fed365)

With this change, the remaining MatMul + Add pairs are fused into Gemm *after* the *AttentionFusion transformers run.
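For intuition, here is a small standalone C++ sketch (illustrative only, not part of this commit) of why the rewrite is safe: for a 2D input X of shape [M, K], weight W of shape [K, N], and bias B of shape [N], Gemm(X, W, B) with alpha = beta = 1 computes exactly Add(MatMul(X, W), B).

```cpp
// Illustrative sketch only: checks that MatMul followed by a broadcast Add of the
// bias equals Gemm with C = B and alpha = beta = 1, which is what the fusion relies on.
#include <array>
#include <cassert>

int main() {
  constexpr int M = 2, K = 3, N = 2;
  const std::array<float, M * K> X = {1, 2, 3, 4, 5, 6};
  const std::array<float, K * N> W = {1, 0, 0, 1, 1, 1};
  const std::array<float, N> B = {0.5f, -0.5f};

  std::array<float, M * N> mm{};    // MatMul(X, W)
  std::array<float, M * N> gemm{};  // Gemm(X, W, B)
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += X[m * K + k] * W[k * N + n];
      mm[m * N + n] = acc;
      gemm[m * N + n] = acc + B[n];
    }
  }
  // The unfused graph computes Add(MatMul(X, W), B); both paths agree elementwise.
  for (int i = 0; i < M * N; ++i) assert(mm[i] + B[i % N] == gemm[i]);
  return 0;
}
```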
1 parent f3c18ed commit 93ee7bf

File tree: 7 files changed, +64 -13 lines changed


onnxruntime/core/optimizer/graph_transformer_mgr.cc

Lines changed: 5 additions & 1 deletion
```diff
@@ -4,6 +4,9 @@
 #include "core/optimizer/graph_transformer_mgr.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
 
+#include <memory>
+#include <utility>
+
 using namespace onnxruntime;
 using namespace ::onnxruntime::common;
 
@@ -60,7 +63,8 @@ void GraphTransformerManager::ClearGraphModified(void) {
 common::Status GraphTransformerManager::Register(std::unique_ptr<GraphTransformer> transformer,
                                                  TransformerLevel level) {
   const auto& name = transformer->Name();
-  if (transformers_info_.find(name) != transformers_info_.end()) {
+  const auto& registered = level_to_transformer_map_[level];
+  if (std::find(registered.begin(), registered.end(), transformer) != registered.end()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is already registered " + name);
   }
```
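With this change, the duplicate check is scoped to a single level, which is what allows the new test below to register MatMulAddFusion at both Level1 and Level2. A minimal usage sketch mirroring that test (assumes onnxruntime's internal headers and the test suite's ASSERT_STATUS_OK macro):

```cpp
// Same transformer name registered at two different levels: accepted after this
// change. The old name-keyed transformers_info_ lookup rejected the second call.
onnxruntime::GraphTransformerManager mgr{/*steps=*/5};
const InlinedHashSet<std::string_view> empty_list = {};
ASSERT_STATUS_OK(mgr.Register(
    std::make_unique<MatMulAddFusion>(empty_list, /*preserve_attention_pattern=*/true),
    TransformerLevel::Level1));
ASSERT_STATUS_OK(mgr.Register(
    std::make_unique<MatMulAddFusion>(empty_list, /*preserve_attention_pattern=*/false),
    TransformerLevel::Level2));
```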

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 4 additions & 5 deletions
```diff
@@ -215,6 +215,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
   const InlinedHashSet<std::string_view> cpu_acl_eps = {onnxruntime::kCpuExecutionProvider,
                                                         onnxruntime::kAclExecutionProvider};
 #endif
+  const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
   const InlinedHashSet<std::string_view> dml_ep = {onnxruntime::kDmlExecutionProvider};
   AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance();
 
@@ -243,7 +244,6 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       for (const auto& p : session_options.initializers_to_share_map) {
         excluded_initializers.insert(p.first);
       }
-      const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
       transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
       transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
       transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
@@ -363,14 +363,13 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_acl_cuda_dml_rocm_eps));
       transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(cpu_cuda_rocm_eps));
       transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
-
       transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_dml_rocm_eps));
       transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_acl_cuda_dml_rocm_eps));
-
       transformers.emplace_back(std::make_unique<GroupQueryAttentionFusion>(cuda_eps));
-
+      // Run MatMulAddFusion again after *AttentionFusion transforms with `preserve_attention_pattern = false`,
+      // to cleanup the remaining MatMul-Add that were part of the attention pattern but not detected or fused.
+      transformers.emplace_back(std::make_unique<MatMulAddFusion>(no_limit_empty_ep_list, false));
       transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_acl_cuda_dml_rocm_eps));
-
       transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_dml_rocm_eps));
       transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_acl_cuda_dml_rocm_eps));
```

onnxruntime/core/optimizer/matmul_add_fusion.cc

Lines changed: 10 additions & 4 deletions
```diff
@@ -1,11 +1,17 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/common/inlined_containers.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/graph_transformer_utils.h"
 #include "core/optimizer/initializer.h"
 #include "core/optimizer/matmul_add_fusion.h"
-#include "core/graph/graph_utils.h"
-#include "core/framework/tensorprotoutils.h"
-#include <deque>
+
+#include <string>
+#include <string_view>
+#include <unordered_set>
+#include <vector>
 
 using namespace ONNX_NAMESPACE;
 using namespace ::onnxruntime::common;
@@ -128,7 +134,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   int64_t m = 0, k = 0, n = 0;
   if (need_reshape) {
     // Only check and skip Attention pattern here because normally input to Attention is 4D.
-    if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
+    if (preserve_attention_pattern_ && attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
       continue;
     }
```

onnxruntime/core/optimizer/matmul_add_fusion.h

Lines changed: 7 additions & 2 deletions
```diff
@@ -9,10 +9,15 @@ namespace onnxruntime {
 
 class MatMulAddFusion : public GraphTransformer {
  public:
-  MatMulAddFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
-      : GraphTransformer("MatMulAddFusion", compatible_execution_providers) {}
+  MatMulAddFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
+                  const bool preserve_attention_pattern = true) noexcept
+      : GraphTransformer("MatMulAddFusion", compatible_execution_providers),
+        preserve_attention_pattern_(preserve_attention_pattern) {}
 
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+
+ private:
+  bool preserve_attention_pattern_;
 };
 
 }  // namespace onnxruntime
```
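Since `preserve_attention_pattern` defaults to true, existing call sites keep their old behavior; only the new post-AttentionFusion cleanup pass opts out. A hedged construction sketch (assumes the internal onnxruntime headers):

```cpp
// Default: attention-pattern MatMul + Add pairs are left alone for *AttentionFusion.
auto default_pass = std::make_unique<MatMulAddFusion>();
// Cleanup pass added in graph_transformer_utils.cc above: fuse whatever remains.
auto cleanup_pass = std::make_unique<MatMulAddFusion>(
    InlinedHashSet<std::string_view>{}, /*preserve_attention_pattern=*/false);
```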

onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@ namespace qnn {
 
 /**
  * An ONNX MatMul can be translated to either a QNN MatMul or a QNN FullyConnected.
- * ONNX's MatMul suports inputs of rank 1, but neither QNN's MatMul nor FullyConnected support two rank 1 inputs.
+ * ONNX's MatMul supports inputs of rank 1, but neither QNN's MatMul nor FullyConnected support two rank 1 inputs.
  * So, we need to add Reshape Ops if necessary.
  * In two cases, FullyConnected (input_1's shape is [n, k]) is used instead of MatMul without extra Transpose Op:
  * 1. input_1 is a rank 2 initializer.
```

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 37 additions & 0 deletions
```diff
@@ -2696,6 +2696,43 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NeedReshape_3D) {
                                           1, pre_graph_checker, post_graph_checker));
 }
 
+// With attention pattern, but targeting an execution provider that does not perform
+// AttentionFusion, fuse into GEMM should still be happen, rather than skipping them
+TEST_F(GraphTransformationTests, MatMulAddFusion_PreserveAttentionPattern) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/matmul_add_from_attention.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+
+  // This toy model contains 11 MatMul + Add pairs, 0 GEMMs.
+  // 7 of them are out of "Attention Pattern" (see MatMulAddFusion::IsAttentionPattern)
+  // 4 of them are in "Attention Pattern" conditionally skipped by MatMulAddFusion pass
+  OpCountMap op_count_before = CountOpsInGraph(p_model->MainGraph());
+  const InlinedHashSet<std::string_view> empty_list = {};
+
+  // In attention pattern, 4 MatMul + Add pairs should be preserved
+  Graph& graph = p_model->MainGraph();
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<MatMulAddFusion>(empty_list, /*preserve_attention_pattern=*/true), TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+  OpCountMap op_count_cpu_ep = CountOpsInGraph(graph);
+  constexpr int expected_fusions1 = 11 - 4;
+  ASSERT_EQ(op_count_cpu_ep["MatMul"], op_count_before["MatMul"] - expected_fusions1);
+  ASSERT_EQ(op_count_cpu_ep["Add"], op_count_before["Add"] - expected_fusions1);
+  ASSERT_EQ(op_count_cpu_ep["Gemm"], op_count_before["Gemm"] + expected_fusions1);
+
+  // In attention pattern, 4 MatMul + Add pairs should be fused into Gemm
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<MatMulAddFusion>(empty_list, /*preserve_attention_pattern=*/false), TransformerLevel::Level2));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
+  OpCountMap op_count_qnn_ep = CountOpsInGraph(graph);
+  constexpr int expected_fusions2 = 11;
+  ASSERT_EQ(op_count_qnn_ep["MatMul"], op_count_before["MatMul"] - expected_fusions2);
+  ASSERT_EQ(op_count_qnn_ep["Add"], op_count_before["Add"] - expected_fusions2);
+  ASSERT_EQ(op_count_qnn_ep["Gemm"], op_count_before["Gemm"] + expected_fusions2);
+}
+
 #ifndef DISABLE_CONTRIB_OPS
 TEST_F(GraphTransformationTests, Gemm_Relu_three_input) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx";
```
