
Enabled refit on Python 3.13 #3481


Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion py/torch_tensorrt/_features.py
@@ -37,7 +37,7 @@
_TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
_FX_FE_AVAIL = True
_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
_REFIT_AVAIL = True

ENABLED_FEATURES = FeatureSet(
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
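For context, a minimal sketch (not part of this PR) of the Python-version gate that the `_features.py` change removes; it reuses the same `packaging.version` comparison as the original line and simply shows the before/after value of the flag:

```python
import sys

from packaging import version

# Pre-PR behaviour: refit was only reported as available on Python < 3.13.
python_version = version.parse(sys.version.split()[0])
refit_available_before = python_version < version.parse("3.13")

# Post-PR behaviour: the flag is unconditionally True.
refit_available_after = True

print(refit_available_before, refit_available_after)
```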
48 changes: 9 additions & 39 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -5,7 +5,6 @@
import logging
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import tensorrt as trt
import torch
from torch.export import ExportedProgram
@@ -53,7 +52,7 @@ def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
settings: CompilationSettings = CompilationSettings(),
) -> dict[str, np.ndarray]:
) -> Any:
Collaborator (review comment): Can we not be a little more specific?

"""Find out the weight mapping between weight in exported program and TensorRT engine
Args:
module: FX GraphModule to interpret
@@ -62,26 +61,13 @@
Returns:
Mapping from weight name in TensorRT to actual weight value in np.ndarray
"""
MODULE_MAP = {
"SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
"CONVOLUTION": (
trt.IConvolutionLayer,
[("kernel", "KERNEL"), ("bias", "BIAS")],
),
"DECONVOLUTION": (
trt.IDeconvolutionLayer,
[("kernel", "KERNEL"), ("bias", "BIAS")],
),
"CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
}

output_dtypes = infer_module_output_dtypes(
module,
truncate_double=settings.truncate_double,
)

# Use Interpreter
weight_map = {}
interpreter = TRTInterpreter(
module,
inputs,
@@ -90,24 +76,8 @@
compilation_settings=settings,
)
interpreter._construct_trt_network_def()
net = interpreter.ctx.net
for i in range(net.num_layers):
layer = net[i]
layer_type: str = layer.type.name
if layer_type in MODULE_MAP:
# Cast the parent class to child class to access attributes
# For example: ILayer does not have ILayer.kernel/ILayer.bias
# So we cast it to IConvolutionLayer and access the attributes
layer.__class__ = MODULE_MAP[layer_type][0]
for weight_type, weight_name in MODULE_MAP[layer_type][1]:
weight = layer.__getattribute__(weight_type).copy()
weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
weight_map[f"{layer.name} {weight_name}"] = (
weight,
weight_dtype,
)

return weight_map
return interpreter.ctx.mapping
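Regarding the collaborator's comment on the `-> Any` annotation: a hypothetical, more specific alternative (not part of the PR), assuming `ctx.mapping` only ever holds flattened numpy arrays keyed by `"<layer name> <role>"` as the rest of the diff suggests:

```python
from typing import Dict

import numpy as np

# Hypothetical alias; the PR itself loosens the annotation to Any.
RefitWeightMapping = Dict[str, np.ndarray]


def collect_mapping_sketch() -> RefitWeightMapping:
    # Stand-in for interpreter.ctx.mapping after _construct_trt_network_def();
    # keys follow the "<layer name> <role>" convention used elsewhere in the diff.
    return {
        "conv1 KERNEL": np.zeros(27, dtype=np.float32),
        "bias_const CONSTANT": np.zeros(1, dtype=np.float32),
    }
```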


@needs_refit
@@ -118,13 +88,12 @@ def construct_refit_mapping_from_weight_name_map(
) -> dict[Any, Any]:
engine_weight_map = {}
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)

if sd_weight_name not in state_dict:
# If weights is not in sd, we can leave it unchanged
continue
else:
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
to_torch_device(settings.device)
)
@@ -178,8 +147,8 @@ def _refit_single_trt_engine_with_gm(
for constant_name, val in constant_mapping.items():
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
constant_mapping_with_type[constant_name] = (
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
trt_dtype,
@@ -208,8 +177,9 @@ def _refit_single_trt_engine_with_gm(
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight, datatype = mapping[layer_name]
trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
weight = mapping[layer_name]
trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)

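A self-contained sketch of the updated refit step above, assuming a live `trt.Refitter` and a mapping of numpy arrays; the dtype table here is an illustrative stand-in for torch_tensorrt's `dtype._from(...)` conversion, not the real code path:

```python
import numpy as np
import tensorrt as trt

# Illustrative numpy -> TensorRT dtype table; the real code goes through
# torch_tensorrt's dtype enum instead.
_NP_TO_TRT = {
    np.dtype(np.float32): trt.DataType.FLOAT,
    np.dtype(np.float16): trt.DataType.HALF,
    np.dtype(np.int32): trt.DataType.INT32,
}


def set_named_weight_sketch(
    refitter: trt.Refitter, layer_name: str, weight: np.ndarray
) -> None:
    # As in the new loop above, the TensorRT dtype is derived from the numpy
    # array itself rather than carried alongside it in a (weight, dtype) tuple.
    trt_dtype = _NP_TO_TRT[weight.dtype]
    # trt.Weights only wraps the raw pointer, so `weight` must stay alive
    # until refit_cuda_engine() has run.
    trt_weights = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
    refitter.set_named_weights(layer_name, trt_weights)
```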
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Dict

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.fx.types import TRTNetwork
@@ -19,3 +20,4 @@ class ConversionContext:
default_factory=CompilationSettings
)
requires_output_allocator: bool = False
mapping: Dict[str, Any] = field(default_factory=dict)
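A minimal stand-in (assumed names, not the real classes) showing how the new `mapping` field on the conversion context gets populated, mirroring the `create_constant` change further below:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import numpy as np


@dataclass
class CtxSketch:
    # Only the field added by this PR is modelled here.
    mapping: Dict[str, Any] = field(default_factory=dict)


def record_constant(ctx: CtxSketch, name: str, value: np.ndarray) -> None:
    # Constants are stored flattened and keyed "<name> CONSTANT", matching the
    # key format that construct_refit_mapping and the refit loop consume.
    ctx.mapping[name + " CONSTANT"] = value.reshape(-1)


ctx = CtxSketch()
record_constant(ctx, "my_const", np.arange(4, dtype=np.float32))
print(list(ctx.mapping))  # ['my_const CONSTANT']
```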
14 changes: 4 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -498,19 +498,15 @@ def _save_weight_mapping(self) -> None:
for k, v in self.module.state_dict().items()
}
weight_name_map: dict[str, Any] = {}
np_map = {}
constant_mapping = {}
np_map = self.ctx.mapping
constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
net = self.ctx.net
for i in range(net.num_layers):
layer = net[i]
layer_type: str = layer.type.name
if layer_type in MODULE_MAP:
layer.__class__ = MODULE_MAP[layer_type][0]
# Name mapping
for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]:
weight = layer.__getattribute__(weight_type).copy()
if weight.size == 0:
continue
engine_weight_name = f"{layer.name} {weight_name}"
# Infer the corresponding weight name(s) in state_dict
sd_weight_name_list = (
@@ -538,17 +534,15 @@ def _save_weight_mapping(self) -> None:
elif "bias" in suffix:
sd_weight_name = f"{sd_weight_name}.bias"
else:
# Save the constant weights for future fast refit
sd_weight_name = f"{sd_weight_name}.unknown"
constant_mapping[engine_weight_name] = weight
elif layer_type == "SCALE":
# Batch norm needs all weights to calculate scale and shift
sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
else:
sd_weight_name = f"{sd_weight_name}.{torch_attr}"

weight_name_map[engine_weight_name] = sd_weight_name
np_map[engine_weight_name] = weight
if engine_weight_name in np_map:
weight_name_map[engine_weight_name] = sd_weight_name

# Stage 2: Value mapping
for engine_weight_name, sd_weight_name in weight_name_map.items():
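To make the `_save_weight_mapping` change above concrete: a small sketch of deriving the constant mapping from the shared context mapping, where single-element entries are kept as constants for future fast refit (an assumption taken from the `size == 1` filter in the diff):

```python
import numpy as np

# Stand-in for self.ctx.mapping after network construction.
np_map = {
    "add_1 CONSTANT": np.array([2.0], dtype=np.float32),  # scalar constant
    "conv1 KERNEL": np.zeros(27, dtype=np.float32),       # regular weight
}

# Same filter as the diff: single-element entries are treated as constants.
constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
print(list(constant_mapping))  # ['add_1 CONSTANT']
```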
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -363,6 +363,7 @@ def create_constant(
else:
numpy_value = torch_value.numpy()

ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
constant = ctx.net.add_constant(
shape,
numpy_value,