@@ -60,45 +60,47 @@ def sympy_reduce_product(x):
60
60
class SymbolicShapeInference :
61
61
def __init__ (self , auto_merge , verbose ):
62
62
self .dispatcher_ = {
63
- 'Add' : self ._infer_binary_ops ,
64
- 'AveragePool' : self ._infer_Pool ,
65
- 'Cast' : self ._infer_Cast ,
66
- 'CategoryMapper' : self ._infer_CategoryMapper ,
67
- 'Compress' : self ._infer_Compress ,
68
- 'Concat' : self ._infer_Concat ,
69
- 'ConstantOfShape' : self ._infer_ConstantOfShape ,
70
- 'Conv' : self ._infer_Conv ,
71
- 'CumSum' : self ._pass_on_shape_and_type ,
72
- 'Div' : self ._infer_binary_ops ,
73
- 'Expand' : self ._infer_Expand ,
74
- 'Gather' : self ._infer_Gather ,
75
- 'GatherElements' : self ._infer_GatherElements ,
76
- 'Loop' : self ._infer_Loop ,
77
- 'MatMulInteger16' : self ._infer_MatMulInteger ,
78
- 'MaxPool' : self ._infer_Pool ,
79
- 'Max' : self ._infer_binary_ops ,
80
- 'Min' : self ._infer_binary_ops ,
81
- 'Mul' : self ._infer_binary_ops ,
82
- 'NonMaxSuppression' : self ._infer_NonMaxSuppression ,
83
- 'NonZero' : self ._infer_NonZero ,
84
- 'OneHot' : self ._infer_OneHot ,
85
- 'Pad' : self ._infer_Pad ,
86
- 'Range' : self ._infer_Range ,
87
- 'ReduceProd' : self ._infer_ReduceProd ,
88
- 'Reshape' : self ._infer_Reshape ,
89
- 'Resize' : self ._infer_Resize ,
90
- 'Round' : self ._pass_on_shape_and_type ,
91
- 'Scan' : self ._infer_Scan ,
92
- 'ScatterElements' : self ._infer_ScatterElements ,
93
- 'Shape' : self ._infer_Shape ,
94
- 'Size' : self ._infer_Size ,
95
- 'Slice' : self ._infer_Slice ,
96
- 'Split' : self ._infer_Split ,
97
- 'Squeeze' : self ._infer_Squeeze ,
98
- 'Sub' : self ._infer_binary_ops ,
99
- 'Tile' : self ._infer_Tile ,
100
- 'TopK' : self ._infer_TopK ,
101
- 'Unsqueeze' : self ._infer_Unsqueeze }
63
+ 'Add' : self ._infer_binary_ops ,
64
+ 'ArrayFeatureExtractor' : self ._infer_ArrayFeatureExtractor ,
65
+ 'AveragePool' : self ._infer_Pool ,
66
+ 'Cast' : self ._infer_Cast ,
67
+ 'CategoryMapper' : self ._infer_CategoryMapper ,
68
+ 'Compress' : self ._infer_Compress ,
69
+ 'Concat' : self ._infer_Concat ,
70
+ 'ConstantOfShape' : self ._infer_ConstantOfShape ,
71
+ 'Conv' : self ._infer_Conv ,
72
+ 'CumSum' : self ._pass_on_shape_and_type ,
73
+ 'Div' : self ._infer_binary_ops ,
74
+ 'Expand' : self ._infer_Expand ,
75
+ 'Gather' : self ._infer_Gather ,
76
+ 'GatherElements' : self ._infer_GatherElements ,
77
+ 'Loop' : self ._infer_Loop ,
78
+ 'MatMulInteger16' : self ._infer_MatMulInteger ,
79
+ 'MaxPool' : self ._infer_Pool ,
80
+ 'Max' : self ._infer_binary_ops ,
81
+ 'Min' : self ._infer_binary_ops ,
82
+ 'Mul' : self ._infer_binary_ops ,
83
+ 'NonMaxSuppression' : self ._infer_NonMaxSuppression ,
84
+ 'NonZero' : self ._infer_NonZero ,
85
+ 'OneHot' : self ._infer_OneHot ,
86
+ 'Pad' : self ._infer_Pad ,
87
+ 'Range' : self ._infer_Range ,
88
+ 'ReduceProd' : self ._infer_ReduceProd ,
89
+ 'Reshape' : self ._infer_Reshape ,
90
+ 'Resize' : self ._infer_Resize ,
91
+ 'Round' : self ._pass_on_shape_and_type ,
92
+ 'Scan' : self ._infer_Scan ,
93
+ 'ScatterElements' : self ._infer_ScatterElements ,
94
+ 'Shape' : self ._infer_Shape ,
95
+ 'Size' : self ._infer_Size ,
96
+ 'Slice' : self ._infer_Slice ,
97
+ 'Split' : self ._infer_Split ,
98
+ 'Squeeze' : self ._infer_Squeeze ,
99
+ 'Sub' : self ._infer_binary_ops ,
100
+ 'Tile' : self ._infer_Tile ,
101
+ 'TopK' : self ._infer_TopK ,
102
+ 'Unsqueeze' : self ._infer_Unsqueeze ,
103
+ 'ZipMap' : self ._infer_ZipMap }
102
104
self .run_ = True
103
105
self .suggested_merge_ = {}
104
106
self .symbolic_dims_ = {}
@@ -394,13 +396,17 @@ def _compute_conv_pool_shape(self, node):
394
396
395
397
# only need to symbolic shape inference if input has symbolic dims in spatial axes
396
398
is_symbolic_dims = [not is_literal (i ) for i in sympy_shape [- rank :]]
399
+
397
400
if not any (is_symbolic_dims ):
398
- sympy_shape [- rank :] = [sympy .Integer (d ) for d in get_shape_from_type_proto (self .known_vi_ [node .output [0 ]].type )[- rank :]]
399
- return sympy_shape
401
+ shape = get_shape_from_type_proto (self .known_vi_ [node .output [0 ]].type )
402
+ if len (shape ) > 0 :
403
+ assert len (sympy_shape ) == len (shape )
404
+ sympy_shape [- rank :] = [sympy .Integer (d ) for d in shape [- rank :]]
405
+ return sympy_shape
400
406
401
407
dilations = get_attribute (node , 'dilations' , [1 ]* rank )
402
408
strides = get_attribute (node , 'strides' , [1 ]* rank )
403
- effective_kernel_shape = [(k - 1 ) * d + 1 for k , d in zip (kernel_shape , strides )]
409
+ effective_kernel_shape = [(k - 1 ) * d + 1 for k , d in zip (kernel_shape , dilations )]
404
410
pads = get_attribute (node , 'pads' )
405
411
if pads is None :
406
412
pads = [0 ]* (2 * rank )
@@ -411,6 +417,8 @@ def _compute_conv_pool_shape(self, node):
411
417
total_pads = [max (0 , (k - s ) if r == 0 else (k - r )) for k , s , r in zip (effective_kernel_shape , strides , residual )]
412
418
except TypeError : # sympy may throw TypeError: cannot determine truth value of Relational
413
419
total_pads = [max (0 , (k - s )) for k , s in zip (effective_kernel_shape , strides )] # assuming no residual if sympy throws error
420
+ elif auto_pad == 'VALID' :
421
+ total_pads = []
414
422
else :
415
423
total_pads = [0 ]* rank
416
424
else :
@@ -419,14 +427,24 @@ def _compute_conv_pool_shape(self, node):
419
427
420
428
ceil_mode = get_attribute (node , 'ceil_mode' , 0 )
421
429
for i in range (rank ):
422
- effective_input_size = sympy_shape [- rank + i ] + total_pads [i ]
430
+ effective_input_size = sympy_shape [- rank + i ]
431
+ if len (total_pads ) > 0 :
432
+ effective_input_size = effective_input_size + total_pads [i ]
423
433
if ceil_mode :
424
434
strided_kernel_positions = sympy .ceiling ((effective_input_size - effective_kernel_shape [i ]) / strides [i ])
425
435
else :
426
436
strided_kernel_positions = (effective_input_size - effective_kernel_shape [i ]) // strides [i ]
427
437
sympy_shape [- rank + i ] = strided_kernel_positions + 1
428
438
return sympy_shape
429
439
440
+ def _infer_ArrayFeatureExtractor (self , node ):
441
+ data_shape = self ._get_shape (node , 0 )
442
+ indices_shape = self ._get_shape (node , 1 )
443
+ vi = self .known_vi_ [node .output [0 ]]
444
+ vi .CopyFrom (helper .make_tensor_value_info (node .output [0 ],
445
+ self .known_vi_ [node .input [0 ]].type .tensor_type .elem_type ,
446
+ data_shape [:- 1 ] + indices_shape ))
447
+
430
448
def _infer_binary_ops (self , node ):
431
449
funcs = {'Add' : lambda l : l [0 ] + l [1 ],
432
450
'Div' : lambda l : l [0 ] // l [1 ], # integer div in sympy
@@ -874,6 +892,21 @@ def _infer_TopK(self, node):
874
892
def _infer_Unsqueeze (self , node ):
875
893
self ._pass_on_sympy_data (node )
876
894
895
+ def _infer_ZipMap (self , node ):
896
+ map_key_type = None
897
+ if get_attribute (node , 'classlabels_int64s' ) is not None :
898
+ map_key_type = onnx .TensorProto .INT64
899
+ elif get_attribute (node , 'classlabels_strings' ) is not None :
900
+ map_key_type = onnx .TensorProto .STRING
901
+
902
+ assert map_key_type is not None
903
+ new_vi = onnx .ValueInfoProto ()
904
+ new_vi .name = node .output [0 ]
905
+ new_vi .type .sequence_type .elem_type .map_type .value_type .tensor_type .elem_type = onnx .TensorProto .FLOAT
906
+ new_vi .type .sequence_type .elem_type .map_type .key_type = map_key_type
907
+ vi = self .known_vi_ [node .output [0 ]]
908
+ vi .CopyFrom (new_vi )
909
+
877
910
def _infer_impl (self , in_mp ):
878
911
self .sympy_data_ = {}
879
912
self .out_mp_ .graph .ClearField ('value_info' )
@@ -906,6 +939,10 @@ def _infer_impl(self, in_mp):
906
939
print (node .op_type + ': ' + node .name )
907
940
for i_o in range (len (node .output )):
908
941
out_type = self .known_vi_ [node .output [i_o ]].type
942
+ out_type_kind = out_type .WhichOneof ('value' )
943
+ # only TensorProto and SparseTensorProto have shape
944
+ if out_type_kind != 'tensor_type' and out_type_kind != 'sparse_tensor_type' :
945
+ continue
909
946
out_shape = get_shape_from_type_proto (self .known_vi_ [node .output [i_o ]].type )
910
947
out_type_undefined = out_type .tensor_type .elem_type == onnx .TensorProto .UNDEFINED
911
948
if self .verbose_ > 2 :
0 commit comments