@@ -62,26 +62,13 @@ def construct_refit_mapping(
     Returns:
         Mapping from weight name in TensorRT to actual weight value in np.ndarray
     """
-    MODULE_MAP = {
-        "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
-        "CONVOLUTION": (
-            trt.IConvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "DECONVOLUTION": (
-            trt.IDeconvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
-    }

     output_dtypes = infer_module_output_dtypes(
         module,
         truncate_double=settings.truncate_double,
     )

     # Use Interpreter
-    weight_map = {}
     interpreter = TRTInterpreter(
         module,
         inputs,
@@ -90,24 +77,8 @@ def construct_refit_mapping(
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
-    net = interpreter.ctx.net
-    for i in range(net.num_layers):
-        layer = net[i]
-        layer_type: str = layer.type.name
-        if layer_type in MODULE_MAP:
-            # Cast the parent class to child class to access attributes
-            # For example: ILayer does not have ILayer.kernel/ILayer.bias
-            # So we cast it to IConvolutionLayer and access the attributes
-            layer.__class__ = MODULE_MAP[layer_type][0]
-            for weight_type, weight_name in MODULE_MAP[layer_type][1]:
-                weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
-                weight_map[f"{layer.name} {weight_name}"] = (
-                    weight,
-                    weight_dtype,
-                )

-    return weight_map
+    return interpreter.ctx.mapping


 @needs_refit
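Note: a minimal sketch of the mapping contract this hunk implies, inferred from the consumer in _refit_single_trt_engine_with_gm further down; the key names and shapes are illustrative stand-ins, not values from this PR. After _construct_trt_network_def(), interpreter.ctx.mapping is assumed to map "<layer name> <weight role>" strings to bare np.ndarray values, so the dtype travels with the array rather than in a (weight, trt.DataType) tuple as before:

import numpy as np

# Hypothetical stand-in for interpreter.ctx.mapping; keys follow the
# "<layer name> <ROLE>" pattern consumed by the refit loop below.
mapping: dict[str, np.ndarray] = {
    "conv1 KERNEL": np.zeros((8, 3, 3, 3), dtype=np.float32),
    "conv1 BIAS": np.zeros((8,), dtype=np.float32),
}
weight = mapping["conv1 KERNEL"]  # bare array; no paired trt.DataType
print(weight.dtype)  # float32, recovered from the array itself at refit time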
@@ -118,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-
         if sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
+            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
@@ -208,8 +178,9 @@ def _refit_single_trt_engine_with_gm(
             if layer_name not in mapping:
                 raise AssertionError(f"{layer_name} is not found in weight mapping")
             # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            weight = mapping[layer_name]
+            trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
+            trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
             refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
             refitted.add(layer_name)
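Note: a minimal, self-contained sketch of how the refit loop above turns a mapped numpy array into a trt.Weights; the array here is a stand-in, and `dtype` is torch_tensorrt's public dtype enum as used in the diff:

import numpy as np
import tensorrt as trt
from torch_tensorrt import dtype

weight = np.ones((16, 16), dtype=np.float16)  # stand-in refit weight

# Derive the TensorRT dtype from the array itself, as the new code does,
# instead of carrying a trt.DataType alongside each weight in the mapping.
trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType)

# trt.Weights wraps the existing buffer by pointer, so `weight` must stay
# alive until the refitter has consumed the weights.
trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)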