Commit e8b0bdb

Shape inference: ReduceMean dispatcher, quant_pre_process: skip_symbolic_shape bugfix (microsoft#23558)
### Description

- Add a symbolic shape inference dispatcher for `ReduceMean`.
- `ReduceMean` is used in RMSNorm, so shape inference currently fails for torch-exported models such as llama and phi.
- Reuse the `ReduceSum` dispatcher, since `ReduceMean` 18+ and `ReduceSum` 13+ have the same spec apart from the type of reduction performed.
- Fix an issue in the `quant_pre_process` tool where the external data file is missing if `skip_symbolic_shape=True` and `skip_optimization=False`.
- Add `"session.optimized_model_external_initializers_file_name"` to the session options so that the external data is saved in the same temp directory as the optimized model.
1 parent 267b493 commit e8b0bdb

2 files changed, +12 -0 lines changed
onnxruntime/python/tools/quantization/shape_inference.py (+6)

```diff
@@ -119,6 +119,12 @@ def quant_pre_process(
                     external_names, external_values = extract_raw_data_from_model(input_model)
                     sess_option.add_external_initializers(list(external_names), list(external_values))
                     input_model = input_model.SerializeToString()
+                # the saved optimized model otherwise points to the original external data file name
+                # which is not available relative to the optimized model file
+                elif skip_symbolic_shape and save_as_external_data:
+                    sess_option.add_session_config_entry(
+                        "session.optimized_model_external_initializers_file_name", "optimized.onnx.data"
+                    )

                 sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
                 # Close the session to avoid the cleanup error on Windows for temp folders
```
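
The new `elif` only takes effect for one combination of flags. The sketch below restates that decision as a standalone function so the cases can be checked without ONNX Runtime installed. `optimizer_config_entries` and its `preceding_branch_taken` parameter are hypothetical names introduced for illustration; the condition of the branch that precedes the `elif` is not visible in this hunk, so it is modeled as a plain boolean.

```python
# Config key and file name copied from the diff above.
EXTERNAL_INITIALIZERS_KEY = "session.optimized_model_external_initializers_file_name"
OPTIMIZED_DATA_FILE = "optimized.onnx.data"


def optimizer_config_entries(skip_symbolic_shape, save_as_external_data, preceding_branch_taken=False):
    """Hypothetical helper: session config entries the new ``elif`` would add.

    ``preceding_branch_taken`` stands in for the earlier branch in the hunk,
    which hands initializers to the session in memory. Its condition is not
    shown in the diff, so it is an assumption here.
    """
    entries = {}
    if preceding_branch_taken:
        # The elif is skipped whenever the earlier branch runs.
        return entries
    if skip_symbolic_shape and save_as_external_data:
        # Ask the optimizer to write external data next to the optimized model.
        entries[EXTERNAL_INITIALIZERS_KEY] = OPTIMIZED_DATA_FILE
    return entries


if __name__ == "__main__":
    print(optimizer_config_entries(skip_symbolic_shape=True, save_as_external_data=True))
    print(optimizer_config_entries(skip_symbolic_shape=False, save_as_external_data=True))
```

In the real tool the entry is applied with `sess_option.add_session_config_entry(key, value)`, as the diff shows; the dictionary here is only a stand-in for that call.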

onnxruntime/python/tools/symbolic_shape_infer.py (+6)

```diff
@@ -166,6 +166,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "Range": self._infer_Range,
             "Reciprocal": self._pass_on_shape_and_type,
             "ReduceSum": self._infer_ReduceSum,
+            "ReduceMean": self._infer_ReduceMean,
             "ReduceProd": self._infer_ReduceProd,
             "Reshape": self._infer_Reshape,
             "Resize": self._infer_Resize,
@@ -1603,6 +1604,11 @@ def _infer_ReduceSum(self, node): # noqa: N802
                         )
                     )

+    def _infer_ReduceMean(self, node): # noqa: N802
+        if get_opset(self.out_mp_) >= 18:
+            # reduce mean spec 18+ is same as reduce sum spec 13+
+            self._infer_ReduceSum(node)
+
     def _infer_ReduceProd(self, node): # noqa: N802
         axes = get_attribute(node, "axes")
         keep_dims = get_attribute(node, "keepdims", 1)
```
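
`_infer_ReduceMean` can delegate to `_infer_ReduceSum` because the output shape of both operators depends only on the input shape, `axes`, `keepdims`, and `noop_with_empty_axes`, not on the kind of reduction. The following is a minimal sketch of that shape rule, not the code in `symbolic_shape_infer.py`: `reduce_output_shape` is a hypothetical helper, it assumes `axes` is known statically, and it represents symbolic dimensions as plain strings.

```python
def reduce_output_shape(input_shape, axes=None, keepdims=1, noop_with_empty_axes=0):
    """Hypothetical helper: output shape for ReduceSum 13+ / ReduceMean 18+.

    ``input_shape`` is a list of ints or symbolic dimension names (strings).
    ``axes`` is assumed to be known up front; in the real operators it is an
    optional input tensor and may only be available at run time.
    """
    rank = len(input_shape)
    if not axes:
        if noop_with_empty_axes:
            # Empty axes with noop_with_empty_axes=1 leaves the tensor unchanged.
            return list(input_shape)
        # Empty axes otherwise means "reduce over every dimension".
        axes = range(rank)

    # Negative axes count from the end, as in the ONNX spec.
    normalized = {axis + rank if axis < 0 else axis for axis in axes}
    for axis in normalized:
        if not 0 <= axis < rank:
            raise ValueError(f"axis {axis} is out of range for rank {rank}")

    output_shape = []
    for index, dim in enumerate(input_shape):
        if index in normalized:
            if keepdims:
                output_shape.append(1)  # reduced dimension kept with size 1
        else:
            output_shape.append(dim)  # untouched dimension, symbolic or concrete
    return output_shape


if __name__ == "__main__":
    # RMSNorm-style mean over the hidden dimension with keepdims=1.
    print(reduce_output_shape(["batch", "seq", 4096], axes=[-1]))  # ['batch', 'seq', 1]
    print(reduce_output_shape(["batch", "seq", 4096], axes=[-1], keepdims=0))  # ['batch', 'seq']
```

The RMSNorm case in the usage lines is the one named in the description: a mean over the last axis that keeps the reduced dimension, so the symbolic `batch` and `seq` dimensions pass through unchanged.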
