Commit eaca0cc

Enabled refit on Python 3.13
1 parent 79083b6 commit eaca0cc

5 files changed: +14 -46 lines

py/torch_tensorrt/_features.py (+1 -1)
@@ -37,7 +37,7 @@
 _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
 _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
 _FX_FE_AVAIL = True
-_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
+_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.14")
 
 ENABLED_FEATURES = FeatureSet(
     _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
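For context, the gate is a plain interpreter-version ceiling: sys.version begins with the version string, so split()[0] isolates it before comparison. A standalone reproduction of the relaxed check, using the same packaging.version logic the file already imports:

# Standalone reproduction of the gate above: refit is now advertised on any
# interpreter below 3.14, which newly admits Python 3.13.
import sys
from packaging import version

_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.14")
print(_REFIT_AVAIL)  # True on Python 3.13 with this commit; the old "3.13" ceiling excluded it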

py/torch_tensorrt/dynamo/_refit.py (+6 -35)
@@ -62,26 +62,13 @@ def construct_refit_mapping(
     Returns:
         Mapping from weight name in TensorRT to actual weight value in np.ndarray
     """
-    MODULE_MAP = {
-        "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
-        "CONVOLUTION": (
-            trt.IConvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "DECONVOLUTION": (
-            trt.IDeconvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
-    }
 
     output_dtypes = infer_module_output_dtypes(
         module,
         truncate_double=settings.truncate_double,
     )
 
     # Use Interpreter
-    weight_map = {}
     interpreter = TRTInterpreter(
         module,
         inputs,
@@ -90,24 +77,8 @@
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
-    net = interpreter.ctx.net
-    for i in range(net.num_layers):
-        layer = net[i]
-        layer_type: str = layer.type.name
-        if layer_type in MODULE_MAP:
-            # Cast the parent class to child class to access attributes
-            # For example: ILayer does not have ILayer.kernel/ILayer.bias
-            # So we cast it to IConvolutionLayer and access the attributes
-            layer.__class__ = MODULE_MAP[layer_type][0]
-            for weight_type, weight_name in MODULE_MAP[layer_type][1]:
-                weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
-                weight_map[f"{layer.name} {weight_name}"] = (
-                    weight,
-                    weight_dtype,
-                )
 
-    return weight_map
+    return interpreter.ctx.mapping
 
 
 @needs_refit
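The deleted walk had to retype each base ILayer to its concrete subclass before it could read attributes like kernel and bias; the function now simply returns the mapping that converters populate on the ConversionContext while the network is built. A pure-Python illustration of the retired cast trick, with hypothetical stand-in classes:

# Pure-Python illustration (hypothetical classes) of the retired cast trick:
# TensorRT's bindings expose layers as the base ILayer, so the old code
# reassigned __class__ to the concrete type to reach subclass-only attributes.
class ILayerLike:
    """Stand-in for the base layer type, which lacks .kernel."""

class ConvLayerLike(ILayerLike):
    @property
    def kernel(self):
        return "kernel weights"

layer = ILayerLike()
layer.__class__ = ConvLayerLike   # retype the instance, as the removed loop did
print(layer.kernel)               # subclass attribute is now reachable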
@@ -118,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-
         if sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
+            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
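The two dtype conversions move under the else branch, so names missing from the state_dict are skipped before any conversion work happens. A condensed, hypothetical rendering of the reordered loop:

# Condensed, hypothetical rendering of the reordered loop: dtype conversion
# now runs only after a state_dict hit is confirmed.
def build_engine_weight_map(weight_name_map, state_dict, convert):
    engine_weight_map = {}
    for engine_name, (sd_name, np_dtype) in weight_name_map.items():
        if sd_name not in state_dict:
            continue  # absent weights are left unchanged
        trt_dtype = convert(np_dtype)  # deferred: only on the hit path
        engine_weight_map[engine_name] = (state_dict[sd_name], trt_dtype)
    return engine_weight_map

wmap = {"conv KERNEL": ("conv.weight", "float32")}
print(build_engine_weight_map(wmap, {}, lambda d: d))  # {} since nothing is in state_dict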
@@ -208,8 +178,9 @@ def _refit_single_trt_engine_with_gm(
             if layer_name not in mapping:
                 raise AssertionError(f"{layer_name} is not found in weight mapping")
             # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            weight = mapping[layer_name]
+            trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
+            trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
             refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
             refitted.add(layer_name)
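Since the mapping now stores bare numpy arrays, the refit loop recovers the TensorRT dtype from the array itself instead of unpacking a stored (weight, datatype) tuple. A minimal sketch, assuming TensorRT is installed; trt.DataType.FLOAT stands in for what dtype.try_from(weight.dtype).to(trt.DataType) returns for float32:

# Minimal sketch, assuming TensorRT is installed: wrap a numpy array in
# trt.Weights, deriving the dtype from the array rather than a stored tuple.
import numpy as np
import tensorrt as trt

weight = np.ones(16, dtype=np.float32)
trt_dtype = trt.DataType.FLOAT  # float32 equivalent of the dtype.try_from(...) call
trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
print(trt_wt_tensor.size)  # 16 elements, matching weight.size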

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py (+2 -0)

@@ -1,4 +1,5 @@
 from dataclasses import dataclass, field
+from typing import Any, Dict
 
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.fx.types import TRTNetwork
@@ -19,3 +20,4 @@ class ConversionContext:
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
+    mapping: Dict[str, Any] = field(default_factory=dict)
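The new field gives every conversion context its own weight map, shared between converters and the interpreter. A minimal sketch of the behavior (condensed; the real dataclass carries several more fields):

# Minimal sketch (condensed) of the new ConversionContext.mapping field:
# a per-context dict created fresh for each instance via default_factory.
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ConversionContext:
    mapping: Dict[str, Any] = field(default_factory=dict)

ctx = ConversionContext()
ctx.mapping["conv1 KERNEL"] = "weights..."   # converters record refittable weights here
assert ConversionContext().mapping == {}     # default_factory avoids a shared mutable default

Using field(default_factory=dict) rather than a bare {} default matters here: a literal dict default would be shared across every ConversionContext instance.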

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (+4 -10)

@@ -498,19 +498,15 @@ def _save_weight_mapping(self) -> None:
             for k, v in self.module.state_dict().items()
         }
         weight_name_map: dict[str, Any] = {}
-        np_map = {}
-        constant_mapping = {}
+        np_map = self.ctx.mapping
+        constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
             layer_type: str = layer.type.name
             if layer_type in MODULE_MAP:
-                layer.__class__ = MODULE_MAP[layer_type][0]
                 # Name mapping
                 for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]:
-                    weight = layer.__getattribute__(weight_type).copy()
-                    if weight.size == 0:
-                        continue
                     engine_weight_name = f"{layer.name} {weight_name}"
                     # Infer the corresponding weight name(s) in state_dict
                     sd_weight_name_list = (
@@ -538,17 +534,15 @@ def _save_weight_mapping(self) -> None:
                         elif "bias" in suffix:
                             sd_weight_name = f"{sd_weight_name}.bias"
                         else:
-                            # Save the constant weights for future fast refit
                             sd_weight_name = f"{sd_weight_name}.unknown"
-                            constant_mapping[engine_weight_name] = weight
                     elif layer_type == "SCALE":
                         # Batch norm needs all weights to calculate scale and shift
                         sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
                     else:
                         sd_weight_name = f"{sd_weight_name}.{torch_attr}"
 
-                    weight_name_map[engine_weight_name] = sd_weight_name
-                    np_map[engine_weight_name] = weight
+                    if engine_weight_name in np_map:
+                        weight_name_map[engine_weight_name] = sd_weight_name
 
         # Stage 2: Value mapping
         for engine_weight_name, sd_weight_name in weight_name_map.items():
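_save_weight_mapping no longer copies weights out of each layer; it reads the shared self.ctx.mapping and derives constant_mapping by filtering for single-element arrays. A condensed, runnable illustration of that filter, with hypothetical entries:

# Condensed illustration (hypothetical entries) of the constant_mapping
# filter above: keep only size-1 arrays from the shared numpy map.
import numpy as np

np_map = {
    "linear1 KERNEL": np.zeros(64, dtype=np.float32),      # regular weight
    "mul_42 CONSTANT": np.array([0.5], dtype=np.float32),  # scalar constant
}
constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
print(list(constant_mapping))  # ['mul_42 CONSTANT']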

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+1 -0)

@@ -363,6 +363,7 @@ def create_constant(
     else:
         numpy_value = torch_value.numpy()
 
+    ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
     constant = ctx.net.add_constant(
         shape,
         numpy_value,
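create_constant is where the mapping gets populated for constants: every constant folded into the network also records its flattened values under a "<name> CONSTANT" key, matching the engine-weight naming the refit path looks up. A small illustration with a hypothetical constant name:

# Illustration (hypothetical name): the key written here follows the
# "<layer name> <weight kind>" convention, so constants recorded at
# conversion time can be found again at refit time.
import numpy as np

ctx_mapping = {}                      # stand-in for ConversionContext.mapping
name = "model.scale"                  # hypothetical constant name
numpy_value = np.array([[2.0]], dtype=np.float32)
ctx_mapping[name + " CONSTANT"] = numpy_value.reshape(-1)  # flattened, as in the diff
print(ctx_mapping)  # {'model.scale CONSTANT': array([2.], dtype=float32)}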
