ShapeInferencePass: AssertionError on unloaded initializers + function body shapes not updated #379
Description
Summary
Two bugs discovered in onnx_ir.passes.common.ShapeInferencePass (onnx-ir 0.2.0) while building the onnxruntime/mobius model export pipeline.
Filed by AI agent team from onnxruntime/mobius (comparative analysis in mobius PR #70).
Bug 1: AssertionError when initializers have no const_value
Description
ShapeInferencePass fails with an AssertionError when any graph initializer has const_value=None. This is the standard state for models that have been graph-built but not yet weight-loaded — a common pattern where shape inference is needed before weights are available.
The assert is in _c_api_utils.call_onnx_api (line 46):

```python
assert initializer.const_value is not None
```

Direct model serialization (ir.serde.serialize_model) handles const_value=None gracefully (it emits a warning instead). The _c_api_utils wrapper is stricter than necessary.
Minimal Reproduction
```python
import onnx_ir as ir
from onnx_ir.passes import common as common_passes

float_type = ir.TensorType(ir.DataType.FLOAT)

# Input
X = ir.Value(name="X", shape=ir.Shape([2, 3]), type=float_type)

# Initializer (weight) — no const_value: weights not loaded yet
W = ir.Value(name="W", shape=ir.Shape([3, 4]), type=float_type)

# Node: Y = MatMul(X, W)
node = ir.Node("", "MatMul", inputs=[X, W], num_outputs=1)
Y = node.outputs[0]
Y.name = "Y"

graph = ir.Graph(
    inputs=[X],
    outputs=[Y],
    nodes=[node],
    initializers=[W],
    name="test_graph",
    opset_imports={"": 20},
)
model = ir.Model(graph, ir_version=10)

print(f"W.const_value = {W.const_value}")  # None
print(f"Y.type = {Y.type}")  # None — to be inferred

result = common_passes.ShapeInferencePass()(model)
# Expected: modified=True, Y.type=FLOAT
# Actual: modified=False due to AssertionError in _c_api_utils.call_onnx_api:46
print(f"modified={result.modified}, Y.type={Y.type}")
```

Expected Behavior
ShapeInferencePass should skip initializers without data, or represent them by their declared type/shape, when calling onnx.shape_inference.infer_shapes. The initializer already has its shape and type set; const_value is not needed for shape inference to succeed.
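To see why tensor data is unnecessary here, note that a MatMul output shape is fully determined by its input shapes. A minimal sketch of the 2-D rule in plain Python (illustrative only, not the onnx implementation):

```python
def matmul_shape(a, b):
    """Illustrative 2-D MatMul shape rule: (m, k) @ (k, n) -> (m, n).

    Only the declared shapes are consulted; no tensor data is read,
    which is why const_value=None should not block shape inference.
    """
    if len(a) != 2 or len(b) != 2:
        raise ValueError("this sketch handles 2-D shapes only")
    m, k = a
    k2, n = b
    if k != k2:
        raise ValueError(f"inner dimensions disagree: {k} vs {k2}")
    return (m, n)

# The shapes from the reproduction above: X is (2, 3), W is (3, 4).
print(matmul_shape((2, 3), (3, 4)))  # (2, 4)
```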
Actual Behavior
```
WARNING  onnx_ir.passes.common.shape_inference:shape_inference.py:85 Shape inference failed: %s. Model is left unchanged
Traceback (most recent call last):
  File ".../onnx_ir/passes/common/_c_api_utils.py", line 46, in call_onnx_api
    assert initializer.const_value is not None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
modified=False, Y.type=None
```
Suggested Fix
In _c_api_utils.call_onnx_api, replace the assertion with a conditional skip:

```python
# Before:
assert initializer.const_value is not None

# After (skip initializers without data — shape/type are enough for inference):
if initializer.const_value is None:
    continue
```

The initializer should still be added to model.graph.inputs with its declared type/shape so the ONNX C++ inference engine can use the type information.
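A self-contained sketch of how the skip would behave, using a minimal stand-in for ir.Value (FakeInitializer and collect_tensor_data are hypothetical names, not the onnx-ir API):

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class FakeInitializer:
    """Minimal stand-in for ir.Value: only the fields relevant here."""
    name: str
    const_value: Optional[Any]

def collect_tensor_data(initializers):
    """Proposed behavior: skip initializers whose const_value is None
    instead of asserting, so weight-less graphs still serialize.
    The skipped values' declared type/shape would still be exposed
    to the inference engine via graph inputs."""
    collected = []
    for init in initializers:
        if init.const_value is None:
            continue  # was: assert init.const_value is not None
        collected.append(init.name)
    return collected

inits = [FakeInitializer("W", None), FakeInitializer("B", [0.0, 1.0])]
print(collect_tensor_data(inits))  # ['B'] (W skipped, no AssertionError)
```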
Bug 2: Function body values not updated after shape inference
Description
When a model has local functions (ir.Function), ShapeInferencePass correctly infers output types at call sites in the main graph. However, the function body values (intermediate values inside the function) are not updated — they remain None after the pass runs.
Root cause: _merge_func iterates over model.graphs(), which returns only the main graph. Local function bodies are not included. The ONNX C++ inference engine does infer shapes inside function bodies (they appear in the serialized FunctionProto), but these inferred shapes are never merged back into the ir.Function.graph.
Minimal Reproduction
```python
import onnx_ir as ir
from onnx_ir.passes import common as common_passes

float_type = ir.TensorType(ir.DataType.FLOAT)

# --- Local function: local::MyRelu(X: float) -> Y ---
func_input = ir.Value(name="X", type=float_type)
func_node = ir.Node("", "Relu", inputs=[func_input], num_outputs=1)
func_output = func_node.outputs[0]
func_output.name = "Y"
# func_output has no type set

func_graph = ir.Graph(
    inputs=[func_input],
    outputs=[func_output],
    nodes=[func_node],
    name="MyRelu_body",
    opset_imports={"": 20},
)
func = ir.Function(domain="local", name="MyRelu", graph=func_graph, attributes={})

# --- Main graph: Z = local::MyRelu(X) ---
X = ir.Value(name="X", shape=ir.Shape([2, 3]), type=float_type)
call_node = ir.Node("local", "MyRelu", inputs=[X], num_outputs=1)
Z = call_node.outputs[0]
Z.name = "Z"

main_graph = ir.Graph(
    inputs=[X],
    outputs=[Z],
    nodes=[call_node],
    name="main",
    opset_imports={"": 20, "local": 1},
)
model = ir.Model(main_graph, ir_version=10, functions=[func])

print(f"func_output.type before: {func_output.type}")  # None
print(f"Z.type before: {Z.type}")  # None

result = common_passes.ShapeInferencePass()(model)
print(f"modified={result.modified}")
print(f"Z.type after: {Z.type}")  # FLOAT ✅ (call site inferred)
print(f"func_output.type after: {func_output.type}")  # None ❌ (body not updated)
```

Expected Behavior
After ShapeInferencePass runs, values inside function bodies should have their types/shapes updated, just like main graph values.
Actual Behavior
```
modified=True
Z.type after: FLOAT ✅ call-site output correctly inferred
func_output.type after: None ❌ function body value NOT updated
```

The _merge_func function only iterates model.graphs(), which returns just the main graph:

```python
for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
    # model.graphs() = [main_graph] — function bodies excluded
    ...
```

Suggested Fix
Extend _merge_func to also merge function body shapes. The ONNX protobuf has the inferred shapes in inferred_proto.functions[i].node[j].output[k] / function value_info. These should be merged back into the corresponding ir.Function.graph values.
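A rough sketch of the merge loop, with dict-based stand-ins so it runs standalone. The real implementation would pair model.functions with the inferred proto's functions and reuse the pass's existing value-merging logic; all names and structures below are illustrative:

```python
def merge_function_body_types(ir_functions, inferred_protos):
    """Copy inferred types from each function's value_info back onto the
    matching function body values. ir_functions / inferred_protos are
    name-keyed dicts of toy objects standing in for ir.Function and
    FunctionProto respectively (hypothetical shapes, for illustration)."""
    modified = False
    for name, func in ir_functions.items():
        proto = inferred_protos.get(name)
        if proto is None:
            continue
        # Inferred types keyed by value name, as found in value_info.
        inferred = {vi["name"]: vi["type"] for vi in proto["value_info"]}
        for value in func["body_values"]:
            new_type = inferred.get(value["name"])
            if new_type is not None and value["type"] is None:
                value["type"] = new_type
                modified = True
    return modified

# Toy data mirroring the reproduction: MyRelu's body output "Y" starts untyped.
funcs = {"MyRelu": {"body_values": [{"name": "Y", "type": None}]}}
protos = {"MyRelu": {"value_info": [{"name": "Y", "type": "FLOAT"}]}}
print(merge_function_body_types(funcs, protos))  # True
print(funcs["MyRelu"]["body_values"][0]["type"])  # FLOAT
```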
Environment
onnx-ir: 0.2.0
onnx: 1.20.1 (opset 25)
Python: 3.13
Platform: Linux
Impact
Both bugs affect the mobius model export pipeline, which builds ONNX graphs without weights and relies on shape inference during the optimization phase. Bug 1 causes ShapeInferencePass to be completely unusable on any model with unloaded weights (which is the standard pre-weight-loading state). Bug 2 means function body shapes cannot be inspected or used by downstream passes.
We worked around both issues by using our own SymbolicShapeInferencePass (wrapping onnx-shape-inference), but the fixes to ShapeInferencePass would make it usable as a general-purpose pass for models at any stage of the build pipeline.