Commit 7ae606f

Qnn weight sharing improvement (#23945)
### Description

Improve QNN weight sharing so that only the last session in the weight-sharing group (the session with both share_ep_contexts and stop_share_ep_contexts enabled) generates the .bin file. The .bin file name is taken from the first session, and all generated *_ctx.onnx models point to this single .bin, so no post-processing is required.

Previously, every session generated a _ctx.onnx model with its own .bin file. That required post-processing: walking the generated *_ctx.onnx models to find the last generated *_ctx.bin, updating each *_ctx.onnx to point to that same .bin, and removing the unused .bin files.
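Concretely, a weight-sharing group is driven entirely from session options. Below is a minimal sketch of how the new flow might be exercised, assuming the session option keys `ep.context_enable`, `ep.share_ep_contexts`, and `ep.stop_share_ep_contexts` (the `kOrtSessionOption*` constants in `onnxruntime_session_options_config_keys.h`) and placeholder model/backend paths:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;

  auto make_options = [](bool last_in_group) {
    Ort::SessionOptions so;
    so.AddConfigEntry("ep.context_enable", "1");     // dump a *_ctx.onnx model
    so.AddConfigEntry("ep.share_ep_contexts", "1");  // join the weight-sharing group
    if (last_in_group) {
      // Only this session writes the single shared .bin (the new behavior).
      so.AddConfigEntry("ep.stop_share_ep_contexts", "1");
    }
    so.AppendExecutionProvider("QNN", {{"backend_path", "QnnHtp.dll"}});
    return so;
  };

  // Both *_ctx.onnx models reference the .bin named after the first session;
  // only the second (last) session actually writes it to disk.
  Ort::Session s1(env, ORT_TSTR("model1.onnx"), make_options(/*last_in_group=*/false));
  Ort::Session s2(env, ORT_TSTR("model2.onnx"), make_options(/*last_in_group=*/true));
}
```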

File tree

10 files changed: +158 −257 lines

include/onnxruntime/core/session/onnxruntime_c_api.h (−3)

```diff
@@ -3674,9 +3674,6 @@ struct OrtApi {
  * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
  * - "0": With fp32 precision.
  * - "1": Default. With fp16 precision.
- * "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context.
- * - "0": Default. Disabled.
- * - "1": Enabled.
  * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another
  * execution provider (typically CPU EP).
  * - "0": Disabled. QNN EP will handle quantization and dequantization of graph I/O.
```

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc (+33 −7)

```diff
@@ -10,6 +10,7 @@
 #include "core/providers/qnn/ort_api.h"
 #include "core/providers/qnn/builder/qnn_utils.h"
 #include "core/providers/qnn/builder/qnn_model.h"
+#include "core/providers/qnn/shared_context.h"
 
 namespace onnxruntime {
 namespace qnn {
@@ -207,7 +208,9 @@ Status CreateEPContextNodes(Model* model,
                             const onnxruntime::PathString& context_model_path,
                             bool qnn_context_embed_mode,
                             uint64_t max_spill_fill_buffer_size,
-                            const logging::Logger& logger) {
+                            const logging::Logger& logger,
+                            bool share_ep_contexts,
+                            bool stop_share_ep_contexts) {
   auto& graph = model->MainGraph();
 
   using namespace ONNX_NAMESPACE;
@@ -241,6 +244,7 @@ Status CreateEPContextNodes(Model* model,
       ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload);
     } else {
       onnxruntime::PathString context_bin_path;
+      std::string context_cache_name;
       auto pos = context_model_path.find_last_of(ORT_TSTR("."));
       if (pos != std::string::npos) {
         context_bin_path = context_model_path.substr(0, pos);
@@ -253,14 +257,36 @@ Status CreateEPContextNodes(Model* model,
         graph_name_in_file.replace(name_pos, strlen(kQnnExecutionProvider), "");
       }
       context_bin_path = context_bin_path + ToPathString(graph_name_in_file + ".bin");
-      std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string());
-      std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary);
-      if (!of_stream) {
-        LOGS(logger, ERROR) << "Failed to open create context file.";
-        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
+      context_cache_name = std::filesystem::path(context_bin_path).filename().string();
+
+      // If generate ctx.onnx with share_ep_context enabled, all ctx.onnx should point to the same ctx.bin
+      if (share_ep_contexts) {
+        auto shared_ctx_bin_name = SharedContext::GetInstance().GetSharedCtxBinFileName();
+        if (shared_ctx_bin_name.empty()) {
+          SharedContext::GetInstance().SetSharedCtxBinFileName(context_cache_name);
+        } else {
+          context_cache_name = shared_ctx_bin_name;
+          auto model_folder_path = std::filesystem::path(context_bin_path).parent_path().string();
+          context_bin_path = ToPathString(model_folder_path + "/" + context_cache_name);
+        }
+      }
+
+      // Write the ctx.bin file for the case: 1. no share_ep_context enabled, write for every session
+      // 2. share_ep_context enabled, only write for the last session which has stop_share_ep_contexts enabled
+      if (!share_ep_contexts || (share_ep_contexts && stop_share_ep_contexts)) {
+        std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary);
+        if (!of_stream) {
+          LOGS(logger, ERROR) << "Failed to open create context file.";
+          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
+        }
+        of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
       }
-      of_stream.write(reinterpret_cast<char*>(buffer), buffer_size);
+
       ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
+      if (share_ep_contexts && stop_share_ep_contexts) {
+        SharedContext::GetInstance().ResetSharedCtxBinFileName();
+      }
+
       ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
     }
   } else {
```
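Read as a whole, the new branch implements a small "first session names the .bin, last session writes it" protocol. The following self-contained sketch models just that control flow; the global string and `ResolveBinName` helper are illustrative stand-ins for SharedContext, not ORT types:

```cpp
#include <filesystem>
#include <iostream>
#include <string>

static std::string g_shared_bin_name;  // plays the role of SharedContext state

// Returns the .bin file name a session embeds into its _ctx.onnx model, and
// prints whether that session would actually write the .bin to disk.
std::string ResolveBinName(const std::string& model_path, bool share, bool stop) {
  std::string name = std::filesystem::path(model_path).stem().string() + "_ctx.bin";
  if (share) {
    if (g_shared_bin_name.empty()) {
      g_shared_bin_name = name;  // first session in the group fixes the name
    } else {
      name = g_shared_bin_name;  // later sessions reuse it
    }
  }
  const bool writes_bin = !share || stop;  // non-shared: always; shared: only last
  std::cout << name << (writes_bin ? "  [writes .bin]" : "") << "\n";
  if (share && stop) g_shared_bin_name.clear();  // group finished, reset for the next one
  return name;
}

int main() {
  ResolveBinName("model1.onnx", /*share=*/true, /*stop=*/false);  // model1_ctx.bin
  ResolveBinName("model2.onnx", /*share=*/true, /*stop=*/false);  // model1_ctx.bin
  ResolveBinName("model3.onnx", /*share=*/true, /*stop=*/true);   // model1_ctx.bin [writes .bin]
}
```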

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h (+3 −1)

```diff
@@ -65,6 +65,8 @@ Status CreateEPContextNodes(Model* model,
                             const onnxruntime::PathString& context_model_path,
                             bool qnn_context_embed_mode,
                             uint64_t max_spill_fill_buffer_size,
-                            const logging::Logger& logger);
+                            const logging::Logger& logger,
+                            bool share_ep_contexts,
+                            bool stop_share_ep_contexts);
 }  // namespace qnn
 }  // namespace onnxruntime
```

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc (+14 −4)

```diff
@@ -538,7 +538,7 @@ Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t
   return Status::OK();
 }
 
-Status QnnBackendManager::CreateContext() {
+Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) {
   if (true == context_created_) {
     LOGS_DEFAULT(INFO) << "Context created already.";
     return Status::OK();
@@ -547,7 +547,7 @@ Status QnnBackendManager::CreateContext() {
   QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
   QnnHtpContext_CustomConfig_t custom_config;
   custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
-  custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
+  custom_config.weightSharingEnabled = enable_htp_weight_sharing;
   context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
   context_config_weight_sharing.customConfig = &custom_config;
 
@@ -810,7 +810,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
 // or generate Qnn context binary is enabled -- to get the max spill fill buffer size
 Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
                                        bool load_from_cached_context,
-                                       bool need_load_system_lib) {
+                                       bool need_load_system_lib,
+                                       bool share_ep_contexts) {
   std::lock_guard<std::recursive_mutex> lock(logger_recursive_mutex_);
   if (backend_setup_completed_) {
     LOGS(logger, VERBOSE) << "Backend setup already!";
@@ -865,9 +866,18 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
     LOGS(logger, VERBOSE) << "InitializeProfiling succeed.";
   }
 
+  bool enable_htp_weight_sharing = false;
+  if (share_ep_contexts && !load_from_cached_context) {
+#if defined(__aarch64__) || defined(_M_ARM64)
+    LOGS(logger, WARNING) << "Weight sharing only available with offline generation on x64 platform, not work on real device.";
+#else
+    enable_htp_weight_sharing = true;
+#endif
+  }
+
   if (!load_from_cached_context) {
     if (status.IsOK()) {
-      status = CreateContext();
+      status = CreateContext(enable_htp_weight_sharing);
     }
     if (status.IsOK()) {
       LOGS(logger, VERBOSE) << "CreateContext succeed.";
```

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h (+4 −6)

```diff
@@ -43,7 +43,6 @@ struct QnnBackendManagerConfig {
   uint32_t device_id;
   QnnHtpDevice_Arch_t htp_arch;
   uint32_t soc_model;
-  bool enable_htp_weight_sharing;
 };
 
 class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager> {
@@ -67,8 +66,7 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
         qnn_saver_path_(config.qnn_saver_path),
         device_id_(config.device_id),
         htp_arch_(config.htp_arch),
-        soc_model_(config.soc_model),
-        enable_htp_weight_sharing_(config.enable_htp_weight_sharing) {
+        soc_model_(config.soc_model) {
   }
 
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);
@@ -84,7 +82,8 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
 
   // Initializes handles to QNN resources (device, logger, etc.).
   // NOTE: This function locks the internal `logger_recursive_mutex_`.
-  Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);
+  Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context,
+                      bool need_load_system_lib, bool share_ep_contexts);
 
   Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
 
@@ -155,7 +154,7 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
 
   Status ReleaseProfilehandle();
 
-  Status CreateContext();
+  Status CreateContext(bool enable_htp_weight_sharing);
 
   Status ReleaseContext();
 
@@ -298,7 +297,6 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
   uint32_t device_id_ = 0;
   QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE;
   uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN;
-  bool enable_htp_weight_sharing_ = false;
 };
 
 }  // namespace qnn
```

onnxruntime/core/providers/qnn/qnn_execution_provider.cc (+9 −17)

```diff
@@ -337,19 +337,8 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
     LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
   }
 
-  bool enable_htp_weight_sharing = false;
-  static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
-  auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
-  if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
-    if ("1" == htp_weight_sharing_enabled_pos->second) {
-      enable_htp_weight_sharing = true;
-    } else if ("0" == htp_weight_sharing_enabled_pos->second) {
-      enable_htp_weight_sharing = false;
-    } else {
-      LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
-                            << " only 0 or 1 allowed. Set to 0.";
-    }
-    LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
+  if (qnn_context_embed_mode_ && share_ep_contexts_) {
+    LOGS_DEFAULT(ERROR) << "[EP context generation:] Weight sharing enabled conflict with EP context embed mode. Inference will not work as expected!";
   }
 
   // Add this option because this feature requires QnnSystem lib and it's no supported for Windows x86_64 platform
@@ -406,8 +395,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
                                            qnn_saver_path,
                                            device_id_,
                                            htp_arch,
-                                           soc_model,
-                                           enable_htp_weight_sharing});
+                                           soc_model});
 }
 
 #if defined(_WIN32)
@@ -701,7 +689,9 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
   // It will load the QnnSystem lib if is_qnn_ctx_model=true, and
   // delay the Qnn context creation to Compile() using the cached context binary
   // or generate context cache enable, need to use use QnnSystem lib to parse the binary to get the max spill fill buffer size
-  auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_);
+  auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model,
+                                               context_cache_enabled_ && enable_spill_fill_buffer_,
+                                               share_ep_contexts_);
   if (Status::OK() != rt) {
     LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
     return result;
@@ -1051,7 +1041,9 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
                                        context_model_path,
                                        qnn_context_embed_mode_,
                                        max_spill_fill_buffer_size,
-                                       logger));
+                                       logger,
+                                       share_ep_contexts_,
+                                       stop_share_ep_contexts_));
 
   if (share_ep_contexts_ && !stop_share_ep_contexts_ &&
       nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) {
```
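The new constructor check flags a combination that cannot work: embed mode stores the QNN context binary inside each generated _ctx.onnx, while weight sharing requires all _ctx.onnx models to reference one external .bin. A minimal sketch of the configuration being rejected, assuming the `ep.context_embed_mode` key from `onnxruntime_session_options_config_keys.h` and a placeholder backend path:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  so.AddConfigEntry("ep.context_enable", "1");
  so.AddConfigEntry("ep.context_embed_mode", "1");  // embed QNN context inside each _ctx.onnx
  so.AddConfigEntry("ep.share_ep_contexts", "1");   // conflicts: sharing needs one external .bin,
                                                    // so QNN EP now logs an ERROR for this combo
  so.AppendExecutionProvider("QNN", {{"backend_path", "QnnHtp.dll"}});
}
```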

onnxruntime/core/providers/qnn/shared_context.h (+18)

```diff
@@ -84,6 +84,21 @@ class SharedContext {
     qnn_backend_manager_.reset();
   }
 
+  void SetSharedCtxBinFileName(std::string& shared_ctx_bin_file_name) {
+    const std::lock_guard<std::mutex> lock(mtx_);
+    shared_ctx_bin_file_name_ = shared_ctx_bin_file_name;
+  }
+
+  const std::string& GetSharedCtxBinFileName() {
+    const std::lock_guard<std::mutex> lock(mtx_);
+    return shared_ctx_bin_file_name_;
+  }
+
+  void ResetSharedCtxBinFileName() {
+    const std::lock_guard<std::mutex> lock(mtx_);
+    shared_ctx_bin_file_name_.clear();
+  }
+
  private:
   SharedContext() = default;
   ~SharedContext() = default;
@@ -94,6 +109,9 @@ class SharedContext {
   std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models_;
   // Used for compiling multiple models into same QNN context binary
   std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager_;
+  // Track the shared ctx binary .bin file name, all _ctx.onnx point to this .bin file
+  // only the last session generate the .bin file since it contains all graphs from all sessions.
+  std::string shared_ctx_bin_file_name_;
   // Producer sessions can be in parallel
   // Consumer sessions have to be after producer sessions initialized
   std::mutex mtx_;
```

onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc (+3 −4)

```diff
@@ -48,7 +48,6 @@ namespace qnnctxgen {
       "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
       "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
       "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n"
-      "\t [QNN only] [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n"
       "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n"
       "\t Defaults to '1' (QNN EP handles the graph I/O quantization and dequantization). \n"
       "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary."
@@ -161,8 +160,8 @@ static bool ParseSessionConfigs(const std::string& configs_string,
         std::string str = str_stream.str();
         ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str);
       }
-    } else if (key == "enable_htp_fp16_precision" || key == "enable_htp_weight_sharing" ||
-               key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer") {
+    } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" ||
+               key == "enable_htp_spill_fill_buffer") {
       std::unordered_set<std::string> supported_options = {"0", "1"};
       if (supported_options.find(value) == supported_options.end()) {
         std::ostringstream str_stream;
@@ -173,7 +172,7 @@
       }
     } else {
       ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'vtcm_mb', 'htp_performance_mode',
-'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision', 'enable_htp_weight_sharing',
+'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', 'enable_htp_fp16_precision',
 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])");
     }
```
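With `enable_htp_weight_sharing` removed from the option list, the test tool is driven only by the remaining `-i` config keys. A hypothetical invocation after this change (the `-e`/`-i` flag shapes and the `key|value` syntax are assumptions inferred from the parser above, not verified against the tool's full usage text):

```
./ep_weight_sharing_ctx_gen -e qnn -i "soc_model|60 htp_graph_finalization_optimization_mode|3" ./model1.onnx,./model2.onnx
```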