Skip to content

Commit 9698dd5

Browse files
committed
Revised according to comments
1 parent 3f6b6ab commit 9698dd5

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

py/torch_tensorrt/_features.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
3838
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
3939
_FX_FE_AVAIL = True
40-
_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.14")
40+
_REFIT_AVAIL = True
4141

4242
ENABLED_FEATURES = FeatureSet(
4343
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL

py/torch_tensorrt/dynamo/_refit.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import logging
66
from typing import Any, List, Optional, Sequence, Tuple
77

8-
import numpy as np
98
import tensorrt as trt
109
import torch
1110
from torch.export import ExportedProgram
@@ -53,7 +52,7 @@ def construct_refit_mapping(
5352
module: torch.fx.GraphModule,
5453
inputs: Sequence[Input],
5554
settings: CompilationSettings = CompilationSettings(),
56-
) -> dict[str, np.ndarray]:
55+
) -> Any:
5756
"""Find out the weight mapping between weight in exported program and TensorRT engine
5857
Args:
5958
module: FX GraphModule to interpret
@@ -93,8 +92,8 @@ def construct_refit_mapping_from_weight_name_map(
9392
# If weights is not in sd, we can leave it unchanged
9493
continue
9594
else:
96-
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
97-
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
95+
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
96+
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
9897
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
9998
to_torch_device(settings.device)
10099
)
@@ -148,8 +147,8 @@ def _refit_single_trt_engine_with_gm(
148147
for constant_name, val in constant_mapping.items():
149148
np_weight_type = val.dtype
150149
val_tensor = torch.from_numpy(val).cuda()
151-
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
152-
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
150+
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
151+
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
153152
constant_mapping_with_type[constant_name] = (
154153
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
155154
trt_dtype,
@@ -179,7 +178,7 @@ def _refit_single_trt_engine_with_gm(
179178
raise AssertionError(f"{layer_name} is not found in weight mapping")
180179
# Use Numpy to create weights
181180
weight = mapping[layer_name]
182-
trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
181+
trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
183182
trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
184183
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
185184
refitted.add(layer_name)

0 commit comments

Comments
 (0)