
Commit 792e518

Revised according to comments
1 parent 3f6b6ab commit 792e518

3 files changed: +9 -8 lines changed


py/torch_tensorrt/_features.py

+1 -1
@@ -37,7 +37,7 @@
 _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
 _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
 _FX_FE_AVAIL = True
-_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.14")
+_REFIT_AVAIL = True
 
 ENABLED_FEATURES = FeatureSet(
     _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
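For context, a minimal sketch of how a caller might check this flag before attempting a refit. The FeatureSet field name ("refit") and the refit_module_weights argument list are assumptions for illustration, not taken from this commit.

# Hedged sketch: gate a refit call on the flag that _REFIT_AVAIL feeds into.
# The FeatureSet field name ("refit") is assumed for illustration only.
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo import refit_module_weights

def refit_if_supported(compiled_module, updated_module, sample_inputs):
    if not getattr(ENABLED_FEATURES, "refit", False):
        raise RuntimeError("This torch_tensorrt build does not support engine refit")
    # Swap the engine weights in place using the retrained module's weights.
    return refit_module_weights(compiled_module, updated_module, sample_inputs)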

py/torch_tensorrt/dynamo/_refit.py

+5 -5
@@ -93,8 +93,8 @@ def construct_refit_mapping_from_weight_name_map(
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
@@ -148,8 +148,8 @@ def _refit_single_trt_engine_with_gm(
     for constant_name, val in constant_mapping.items():
         np_weight_type = val.dtype
         val_tensor = torch.from_numpy(val).cuda()
-        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+        trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
         constant_mapping_with_type[constant_name] = (
             val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
             trt_dtype,
@@ -179,7 +179,7 @@ def _refit_single_trt_engine_with_gm(
             raise AssertionError(f"{layer_name} is not found in weight mapping")
         # Use Numpy to create weights
         weight = mapping[layer_name]
-        trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
+        trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
         trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
         refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
         refitted.add(layer_name)
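The behavioral difference behind the swap, sketched below under an assumption about torch_tensorrt._enums.dtype semantics: try_from returns None for an unsupported input while _from raises, so the old chained calls could fail with an opaque AttributeError on None, whereas _from reports the unsupported dtype at the conversion site.

# Hedged sketch of try_from vs _from (assumed semantics: try_from -> None on
# failure, _from -> raises). Only the chaining difference matters here.
import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt import dtype

np_weight_type = np.dtype(np.float16)

# Old pattern: try_from may return None, so .to(...) would raise
# AttributeError: 'NoneType' object has no attribute 'to' for unsupported types.
maybe = dtype.try_from(np_weight_type)
trt_dtype = maybe.to(trt.DataType) if maybe is not None else None

# New pattern: _from raises a descriptive error for unsupported dtypes,
# so failures surface where the conversion is attempted.
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)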

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

+3 -2
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, field
-from typing import Any, Dict
+from typing import Dict
 
+import numpy as np
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.fx.types import TRTNetwork
 
@@ -20,4 +21,4 @@ class ConversionContext:
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
-    mapping: Dict[str, Any] = field(default_factory=dict)
+    mapping: Dict[str, np.array] = field(default_factory=dict)
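A small usage sketch of the tightened annotation, assuming the dataclass's first field is the TRT network and using a hypothetical weight name; in real compilation the Dynamo converter constructs the context with a live TRTNetwork and CompilationSettings.

# Hedged sketch: populating the now numpy-typed mapping (hypothetical key).
import numpy as np
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

ctx = ConversionContext(net=None)  # a TRTNetwork instance in real usage
# Values are numpy weight arrays keyed by name. Note that np.array is the
# array constructor; np.ndarray is the stricter static type, though the
# annotation is not enforced at runtime.
ctx.mapping["layer1.weight"] = np.zeros((64, 32), dtype=np.float32)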
