@@ -109,7 +109,7 @@ def construct_refit_mapping(
109109
110110
111111def construct_refit_mapping_from_weight_name_map (
112- weight_name_map : dict [Any , Any ], state_dict : dict [Any , Any ]
112+ weight_name_map : dict [Any , Any ], state_dict : dict [Any , Any ], device : torch . device
113113) -> dict [Any , Any ]:
114114 engine_weight_map = {}
115115 for engine_weight_name , (sd_weight_name , np_weight_type ) in weight_name_map .items ():
@@ -120,7 +120,11 @@ def construct_refit_mapping_from_weight_name_map(
120120 # If weights is not in sd, we can leave it unchanged
121121 continue
122122 else :
123- engine_weight_map [engine_weight_name ] = state_dict [sd_weight_name ]
123+ engine_weight_map [engine_weight_name ] = (
124+ state_dict [sd_weight_name ]
125+ if state_dict [sd_weight_name ].device == device
126+ else state_dict [sd_weight_name ].to ("device" )
127+ )
124128
125129 engine_weight_map [engine_weight_name ] = (
126130 engine_weight_map [engine_weight_name ]
@@ -162,7 +166,7 @@ def _refit_single_trt_engine_with_gm(
162166 "constant_mapping" , {}
163167 ) # type: ignore
164168 mapping = construct_refit_mapping_from_weight_name_map (
165- weight_name_map , new_gm .state_dict ()
169+ weight_name_map , new_gm .state_dict (), torch_device
166170 )
167171 constant_mapping_with_type = {}
168172
0 commit comments