Skip to content

Commit 5c2803f

Browse files
authored
various fixes for shape inference script (microsoft#2124)
Changes in this commit:
* Use dilations for computing the effective kernel shape for conv/pool ops.
* When auto_pad is 'VALID', total_pads should be empty.
* Added support for ArrayFeatureExtractor and ZipMap.
* Check out_shape only if the output has a shape, i.e. the output is of TensorType or SparseTensorType.
1 parent 95ab5ad commit 5c2803f

File tree

1 file changed

+80
-43
lines changed

1 file changed

+80
-43
lines changed

onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py

Lines changed: 80 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -60,45 +60,47 @@ def sympy_reduce_product(x):
6060
class SymbolicShapeInference:
6161
def __init__(self, auto_merge, verbose):
6262
self.dispatcher_ = {
63-
'Add' : self._infer_binary_ops,
64-
'AveragePool' : self._infer_Pool,
65-
'Cast' : self._infer_Cast,
66-
'CategoryMapper' : self._infer_CategoryMapper,
67-
'Compress' : self._infer_Compress,
68-
'Concat' : self._infer_Concat,
69-
'ConstantOfShape' : self._infer_ConstantOfShape,
70-
'Conv' : self._infer_Conv,
71-
'CumSum' : self._pass_on_shape_and_type,
72-
'Div' : self._infer_binary_ops,
73-
'Expand' : self._infer_Expand,
74-
'Gather' : self._infer_Gather,
75-
'GatherElements' : self._infer_GatherElements,
76-
'Loop' : self._infer_Loop,
77-
'MatMulInteger16' : self._infer_MatMulInteger,
78-
'MaxPool' : self._infer_Pool,
79-
'Max' : self._infer_binary_ops,
80-
'Min' : self._infer_binary_ops,
81-
'Mul' : self._infer_binary_ops,
82-
'NonMaxSuppression' : self._infer_NonMaxSuppression,
83-
'NonZero' : self._infer_NonZero,
84-
'OneHot' : self._infer_OneHot,
85-
'Pad' : self._infer_Pad,
86-
'Range' : self._infer_Range,
87-
'ReduceProd' : self._infer_ReduceProd,
88-
'Reshape' : self._infer_Reshape,
89-
'Resize' : self._infer_Resize,
90-
'Round' : self._pass_on_shape_and_type,
91-
'Scan' : self._infer_Scan,
92-
'ScatterElements' : self._infer_ScatterElements,
93-
'Shape' : self._infer_Shape,
94-
'Size' : self._infer_Size,
95-
'Slice' : self._infer_Slice,
96-
'Split' : self._infer_Split,
97-
'Squeeze' : self._infer_Squeeze,
98-
'Sub' : self._infer_binary_ops,
99-
'Tile' : self._infer_Tile,
100-
'TopK' : self._infer_TopK,
101-
'Unsqueeze' : self._infer_Unsqueeze}
63+
'Add' : self._infer_binary_ops,
64+
'ArrayFeatureExtractor' : self._infer_ArrayFeatureExtractor,
65+
'AveragePool' : self._infer_Pool,
66+
'Cast' : self._infer_Cast,
67+
'CategoryMapper' : self._infer_CategoryMapper,
68+
'Compress' : self._infer_Compress,
69+
'Concat' : self._infer_Concat,
70+
'ConstantOfShape' : self._infer_ConstantOfShape,
71+
'Conv' : self._infer_Conv,
72+
'CumSum' : self._pass_on_shape_and_type,
73+
'Div' : self._infer_binary_ops,
74+
'Expand' : self._infer_Expand,
75+
'Gather' : self._infer_Gather,
76+
'GatherElements' : self._infer_GatherElements,
77+
'Loop' : self._infer_Loop,
78+
'MatMulInteger16' : self._infer_MatMulInteger,
79+
'MaxPool' : self._infer_Pool,
80+
'Max' : self._infer_binary_ops,
81+
'Min' : self._infer_binary_ops,
82+
'Mul' : self._infer_binary_ops,
83+
'NonMaxSuppression' : self._infer_NonMaxSuppression,
84+
'NonZero' : self._infer_NonZero,
85+
'OneHot' : self._infer_OneHot,
86+
'Pad' : self._infer_Pad,
87+
'Range' : self._infer_Range,
88+
'ReduceProd' : self._infer_ReduceProd,
89+
'Reshape' : self._infer_Reshape,
90+
'Resize' : self._infer_Resize,
91+
'Round' : self._pass_on_shape_and_type,
92+
'Scan' : self._infer_Scan,
93+
'ScatterElements' : self._infer_ScatterElements,
94+
'Shape' : self._infer_Shape,
95+
'Size' : self._infer_Size,
96+
'Slice' : self._infer_Slice,
97+
'Split' : self._infer_Split,
98+
'Squeeze' : self._infer_Squeeze,
99+
'Sub' : self._infer_binary_ops,
100+
'Tile' : self._infer_Tile,
101+
'TopK' : self._infer_TopK,
102+
'Unsqueeze' : self._infer_Unsqueeze,
103+
'ZipMap' : self._infer_ZipMap}
102104
self.run_ = True
103105
self.suggested_merge_ = {}
104106
self.symbolic_dims_ = {}
@@ -394,13 +396,17 @@ def _compute_conv_pool_shape(self, node):
394396

395397
# only need to symbolic shape inference if input has symbolic dims in spatial axes
396398
is_symbolic_dims = [not is_literal(i) for i in sympy_shape[-rank:]]
399+
397400
if not any(is_symbolic_dims):
398-
sympy_shape[-rank:] = [sympy.Integer(d) for d in get_shape_from_type_proto(self.known_vi_[node.output[0]].type)[-rank:]]
399-
return sympy_shape
401+
shape = get_shape_from_type_proto(self.known_vi_[node.output[0]].type)
402+
if len(shape) > 0:
403+
assert len(sympy_shape) == len(shape)
404+
sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
405+
return sympy_shape
400406

401407
dilations = get_attribute(node, 'dilations', [1]*rank)
402408
strides = get_attribute(node, 'strides', [1]*rank)
403-
effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, strides)]
409+
effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
404410
pads = get_attribute(node, 'pads')
405411
if pads is None:
406412
pads = [0]*(2*rank)
@@ -411,6 +417,8 @@ def _compute_conv_pool_shape(self, node):
411417
total_pads = [max(0, (k - s) if r == 0 else (k - r)) for k, s, r in zip(effective_kernel_shape, strides, residual)]
412418
except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational
413419
total_pads = [max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)] # assuming no residual if sympy throws error
420+
elif auto_pad == 'VALID':
421+
total_pads = []
414422
else:
415423
total_pads = [0]*rank
416424
else:
@@ -419,14 +427,24 @@ def _compute_conv_pool_shape(self, node):
419427

420428
ceil_mode = get_attribute(node, 'ceil_mode', 0)
421429
for i in range(rank):
422-
effective_input_size = sympy_shape[-rank + i] + total_pads[i]
430+
effective_input_size = sympy_shape[-rank + i]
431+
if len(total_pads) > 0:
432+
effective_input_size = effective_input_size + total_pads[i]
423433
if ceil_mode:
424434
strided_kernel_positions = sympy.ceiling((effective_input_size - effective_kernel_shape[i]) / strides[i])
425435
else:
426436
strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
427437
sympy_shape[-rank + i] = strided_kernel_positions + 1
428438
return sympy_shape
429439

440+
def _infer_ArrayFeatureExtractor(self, node):
441+
data_shape = self._get_shape(node, 0)
442+
indices_shape = self._get_shape(node, 1)
443+
vi = self.known_vi_[node.output[0]]
444+
vi.CopyFrom(helper.make_tensor_value_info(node.output[0],
445+
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
446+
data_shape[:-1] + indices_shape))
447+
430448
def _infer_binary_ops(self, node):
431449
funcs = {'Add' : lambda l: l[0] + l[1],
432450
'Div' : lambda l: l[0] // l[1], # integer div in sympy
@@ -874,6 +892,21 @@ def _infer_TopK(self, node):
874892
def _infer_Unsqueeze(self, node):
875893
self._pass_on_sympy_data(node)
876894

895+
def _infer_ZipMap(self, node):
896+
map_key_type = None
897+
if get_attribute(node, 'classlabels_int64s') is not None:
898+
map_key_type = onnx.TensorProto.INT64
899+
elif get_attribute(node, 'classlabels_strings') is not None:
900+
map_key_type = onnx.TensorProto.STRING
901+
902+
assert map_key_type is not None
903+
new_vi = onnx.ValueInfoProto()
904+
new_vi.name = node.output[0]
905+
new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
906+
new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
907+
vi = self.known_vi_[node.output[0]]
908+
vi.CopyFrom(new_vi)
909+
877910
def _infer_impl(self, in_mp):
878911
self.sympy_data_ = {}
879912
self.out_mp_.graph.ClearField('value_info')
@@ -906,6 +939,10 @@ def _infer_impl(self, in_mp):
906939
print(node.op_type + ': ' + node.name)
907940
for i_o in range(len(node.output)):
908941
out_type = self.known_vi_[node.output[i_o]].type
942+
out_type_kind = out_type.WhichOneof('value')
943+
# only TensorProto and SparseTensorProto have shape
944+
if out_type_kind != 'tensor_type' and out_type_kind != 'sparse_tensor_type':
945+
continue
909946
out_shape = get_shape_from_type_proto(self.known_vi_[node.output[i_o]].type)
910947
out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
911948
if self.verbose_ > 2:

0 commit comments

Comments (0)