File tree 2 files changed +3
-2
lines changed
2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -123,7 +123,7 @@ def construct_refit_mapping_from_weight_name_map(
123
123
engine_weight_map [engine_weight_name ] = (
124
124
state_dict [sd_weight_name ]
125
125
if state_dict [sd_weight_name ].device == device
126
- else state_dict [sd_weight_name ].to (" device" )
126
+ else state_dict [sd_weight_name ].to (device )
127
127
)
128
128
129
129
engine_weight_map [engine_weight_name ] = (
Original file line number Diff line number Diff line change 7
7
import numpy as np
8
8
import torch
9
9
import torch_tensorrt
10
+ from torch .export ._trace import _export
10
11
from torch_tensorrt ._Device import Device
11
12
from torch_tensorrt .dynamo import _defaults
12
13
from torch_tensorrt .dynamo ._compiler import compile as dynamo_compile
@@ -309,7 +310,7 @@ def refit_gm(self) -> None:
309
310
310
311
def get_exported_program (self ) -> torch .export .ExportedProgram :
311
312
if self .allow_complex_guards_as_runtime_asserts :
312
- return torch . export . _trace . _export (
313
+ return _export (
313
314
self .original_model ,
314
315
self .arg_inputs ,
315
316
kwargs = self .kwarg_inputs ,
You can’t perform that action at this time.
0 commit comments