Skip to content

Commit 3200e63

Browse files
committed
Fixed the bug of SDXL Cuda Error
1 parent eeaab2b commit 3200e63

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

py/torch_tensorrt/dynamo/_refit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def construct_refit_mapping_from_weight_name_map(
123123
engine_weight_map[engine_weight_name] = (
124124
state_dict[sd_weight_name]
125125
if state_dict[sd_weight_name].device == device
126-
else state_dict[sd_weight_name].to("device")
126+
else state_dict[sd_weight_name].to(device)
127127
)
128128

129129
engine_weight_map[engine_weight_name] = (

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import torch
99
import torch_tensorrt
10+
from torch.export._trace import _export
1011
from torch_tensorrt._Device import Device
1112
from torch_tensorrt.dynamo import _defaults
1213
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
@@ -309,7 +310,7 @@ def refit_gm(self) -> None:
309310

310311
def get_exported_program(self) -> torch.export.ExportedProgram:
311312
if self.allow_complex_guards_as_runtime_asserts:
312-
return torch.export._trace._export(
313+
return _export(
313314
self.original_model,
314315
self.arg_inputs,
315316
kwargs=self.kwarg_inputs,

0 commit comments

Comments
 (0)