Skip to content

Commit ec84831

Browse files
tianleiwuPrathik Rao
authored and
Prathik Rao
committed
Fix symbolic shape infer empty value_info (#15842)
### Description When node output is optional, symbolic shape infer might add an empty value_info item. Add some checking to avoid this. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> - Stable diffusion optimized model reported invalid data type 0 during inference.
1 parent 7eff605 commit ec84831

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

onnxruntime/python/tools/symbolic_shape_infer.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -491,12 +491,13 @@ def _onnx_infer_single_node(self, node):
491491

492492
for i_o in range(len(node.output)):
493493
o = node.output[i_o]
494-
vi = self.out_mp_.graph.value_info.add()
495-
if not skip_infer:
496-
vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
497-
else:
498-
vi.name = o
499-
self.known_vi_[o] = vi
494+
if o: # skip optional output
495+
vi = self.out_mp_.graph.value_info.add()
496+
if not skip_infer:
497+
vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
498+
else:
499+
vi.name = o
500+
self.known_vi_[o] = vi
500501

501502
def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
502503
if self.verbose_ > 2:

onnxruntime/python/tools/transformers/onnx_model.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -644,12 +644,17 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs):
644644
if model_with_shape is not None:
645645
name_vi = {}
646646
for vi in model_with_shape.graph.value_info:
647-
vi_copy = ValueInfoProto()
648-
vi_copy.CopyFrom(vi)
649-
if hasattr(vi_copy.type, "tensor_type") and hasattr(vi_copy.type.tensor_type, "shape"):
650-
vi_copy.type.tensor_type.ClearField("shape")
651-
name_vi[vi.name] = vi_copy
652-
647+
if (
648+
hasattr(vi.type, "tensor_type")
649+
and hasattr(vi.type.tensor_type, "elem_type")
650+
and vi.type.tensor_type.elem_type != TensorProto.UNDEFINED
651+
and vi.name
652+
):
653+
vi_copy = ValueInfoProto()
654+
vi_copy.CopyFrom(vi)
655+
if hasattr(vi_copy.type.tensor_type, "shape"):
656+
vi_copy.type.tensor_type.ClearField("shape")
657+
name_vi[vi.name] = vi_copy
653658
for vi in model.graph.value_info:
654659
if vi.name in name_vi:
655660
del name_vi[vi.name]

0 commit comments

Comments
 (0)