
ShapeInferencePass: AssertionError on unloaded initializers + function body shapes not updated #379

@justinchuby

Summary

Two bugs discovered in onnx_ir.passes.common.ShapeInferencePass (onnx-ir 0.2.0) while building the onnxruntime/mobius model export pipeline.

Filed by AI agent team from onnxruntime/mobius (comparative analysis in mobius PR #70).


Bug 1: AssertionError when initializers have no const_value

Description

ShapeInferencePass fails with an AssertionError when any graph initializer has const_value=None. This is the standard state for models that have been graph-built but not yet weight-loaded — a common pattern where shape inference is needed before weights are available.

The assert is in _c_api_utils.call_onnx_api (line 46):

assert initializer.const_value is not None

Direct model serialization (ir.serde.serialize_model) handles const_value=None gracefully (emits a warning). The _c_api_utils wrapper is stricter than necessary.

Minimal Reproduction

import onnx_ir as ir
from onnx_ir.passes import common as common_passes

float_type = ir.TensorType(ir.DataType.FLOAT)

# Input
X = ir.Value(name="X", shape=ir.Shape([2, 3]), type=float_type)

# Initializer (weight) — no const_value: weights not loaded yet
W = ir.Value(name="W", shape=ir.Shape([3, 4]), type=float_type)

# Node: Y = MatMul(X, W)
node = ir.Node("", "MatMul", inputs=[X, W], num_outputs=1)
Y = node.outputs[0]
Y.name = "Y"

graph = ir.Graph(
    inputs=[X],
    outputs=[Y],
    nodes=[node],
    initializers=[W],
    name="test_graph",
    opset_imports={"": 20},
)
model = ir.Model(graph, ir_version=10)

print(f"W.const_value = {W.const_value}")  # None
print(f"Y.type = {Y.type}")  # None — to be inferred

result = common_passes.ShapeInferencePass()(model)
# Expected: modified=True, Y.type=FLOAT
# Actual:   modified=False due to AssertionError in _c_api_utils.call_onnx_api:46
print(f"modified={result.modified}, Y.type={Y.type}")

Expected Behavior

ShapeInferencePass should skip initializers that have no data, treating them as just their declared type/shape when calling onnx.shape_inference.infer_shapes. The initializer already has its shape and type set; const_value is not needed for shape inference to succeed.

Actual Behavior

WARNING  onnx_ir.passes.common.shape_inference:shape_inference.py:85 Shape inference failed: %s. Model is left unchanged
Traceback (most recent call last):
  File ".../onnx_ir/passes/common/_c_api_utils.py", line 46, in call_onnx_api
    assert initializer.const_value is not None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
modified=False, Y.type=None

Suggested Fix

In _c_api_utils.call_onnx_api, replace the assertion with a conditional skip:

# Before:
assert initializer.const_value is not None

# After (skip initializers without data — shape/type are enough for inference):
if initializer.const_value is None:
    continue

The initializer should still be added to model.graph.inputs with its declared type/shape so the ONNX C++ inference engine can use the type information.


Bug 2: Function body values not updated after shape inference

Description

When a model has local functions (ir.Function), ShapeInferencePass correctly infers output types at call sites in the main graph. However, the function body values (intermediate values inside the function) are not updated — they remain None after the pass runs.

Root cause: _merge_func iterates over model.graphs(), which returns only the main graph. Local function bodies are not included. The ONNX C++ inference engine does infer shapes inside function bodies (they appear in the serialized FunctionProto), but these inferred shapes are never merged back into the ir.Function.graph.

Minimal Reproduction

import onnx_ir as ir
from onnx_ir.passes import common as common_passes

float_type = ir.TensorType(ir.DataType.FLOAT)

# --- Local function: local::MyRelu(X: float) -> Y ---
func_input = ir.Value(name="X", type=float_type)
func_node = ir.Node("", "Relu", inputs=[func_input], num_outputs=1)
func_output = func_node.outputs[0]
func_output.name = "Y"
# func_output has no type set

func_graph = ir.Graph(
    inputs=[func_input],
    outputs=[func_output],
    nodes=[func_node],
    name="MyRelu_body",
    opset_imports={"": 20},
)
func = ir.Function(domain="local", name="MyRelu", graph=func_graph, attributes={})

# --- Main graph: Z = local::MyRelu(X) ---
X = ir.Value(name="X", shape=ir.Shape([2, 3]), type=float_type)
call_node = ir.Node("local", "MyRelu", inputs=[X], num_outputs=1)
Z = call_node.outputs[0]
Z.name = "Z"

main_graph = ir.Graph(
    inputs=[X],
    outputs=[Z],
    nodes=[call_node],
    name="main",
    opset_imports={"": 20, "local": 1},
)
model = ir.Model(main_graph, ir_version=10, functions=[func])

print(f"func_output.type before: {func_output.type}")  # None
print(f"Z.type before: {Z.type}")  # None

result = common_passes.ShapeInferencePass()(model)

print(f"modified={result.modified}")
print(f"Z.type after: {Z.type}")           # FLOAT ✅ (call site inferred)
print(f"func_output.type after: {func_output.type}")  # None ❌ (body not updated)

Expected Behavior

After ShapeInferencePass runs, values inside function bodies should have their types/shapes updated, just like main graph values.

Actual Behavior

modified=True
Z.type after: FLOAT          ✅ call-site output correctly inferred
func_output.type after: None  ❌ function body value NOT updated

The _merge_func function iterates only model.graphs(), which returns just the main graph:

for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
    # model.graphs() = [main_graph] — function bodies excluded
    ...

Suggested Fix

Extend _merge_func to also merge function body shapes. The serialized protobuf carries the inferred shapes for function bodies (in inferred_proto.functions[i].node[j].output[k] and the function value_info); these should be merged back into the corresponding ir.Function.graph values.
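The merge itself is name-based matching. A minimal sketch of the logic, using plain dataclass stand-ins for ir.Value (hypothetical; the real fix would operate on onnx_ir types and deserialize the inferred FunctionProto value_info):

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for ir.Value, for illustration only.
@dataclass
class Value:
    name: str
    type: Optional[str] = None
    shape: Optional[tuple] = None

def merge_values(original, inferred):
    """Copy inferred type/shape onto original values that lack them."""
    inferred_by_name = {v.name: v for v in inferred}
    modified = False
    for value in original:
        hit = inferred_by_name.get(value.name)
        if hit is None:
            continue
        if value.type is None and hit.type is not None:
            value.type = hit.type
            modified = True
        if value.shape is None and hit.shape is not None:
            value.shape = hit.shape
            modified = True
    return modified

def merge_model(main_pair, function_pairs):
    # The key change: visit function-body value pairs in addition to the
    # main graph pair that model.graphs() already covers.
    modified = False
    for original, inferred in [main_pair, *function_pairs]:
        modified |= merge_values(original, inferred)
    return modified

# A function-body value with no type picks up the inferred one.
body = [Value("Y")]
inferred_body = [Value("Y", type="FLOAT", shape=(2, 3))]
changed = merge_model(([], []), [(body, inferred_body)])
print(changed, body[0].type)  # True FLOAT
```

The same merge-only-if-missing rule the pass already uses for the main graph applies unchanged; the fix is purely about which graphs get visited.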


Environment

onnx-ir:  0.2.0
onnx:     1.20.1 (opset 25)
Python:   3.13
Platform: Linux

Impact

Both bugs affect the mobius model export pipeline, which builds ONNX graphs without weights and relies on shape inference during the optimization phase. Bug 1 makes ShapeInferencePass unusable on any model whose weights are not yet loaded — the standard pre-weight-loading state. Bug 2 means function body shapes cannot be inspected or used by downstream passes.

We worked around both issues by using our own SymbolicShapeInferencePass (wrapping onnx-shape-inference), but the fixes to ShapeInferencePass would make it usable as a general-purpose pass for models at any stage of the build pipeline.
