
fix: Change the translational layer from numpy to torch during conversion to handle additional data types #3445


Merged
12 commits merged on Apr 1, 2025
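Why move the translational layer from numpy to torch: numpy has no native bfloat16 dtype, so any constant routed through numpy either fails outright or needs a lossy staging cast, while torch round-trips such types directly. A minimal sketch of the limitation, assuming a recent PyTorch build (the tensor values are illustrative):

import torch

t = torch.randn(2, 2, dtype=torch.bfloat16)
# t.numpy() raises "TypeError: Got unsupported ScalarType BFloat16":
# numpy has no BF16 dtype, so a direct conversion is impossible.
staged = t.to(torch.float32).numpy()                 # workaround: stage through FP32
back = torch.from_numpy(staged).to(torch.bfloat16)   # and cast back afterwards

Keeping the intermediate representation in torch avoids this boundary entirely, which is what this PR does for refit weights, graph constants, and conv parameters.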
122 changes: 62 additions & 60 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -9,6 +9,7 @@
import tensorrt as trt
import torch
from torch.export import ExportedProgram
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import partitioning
@@ -144,71 +145,72 @@ def _refit_single_trt_engine_with_gm(
Refit a TensorRT Engine in place
"""

refitted = set()
torch_device = get_model_device(new_gm)
refitter = trt.Refitter(old_engine, TRT_LOGGER)
weight_list = refitter.get_all_weights()

if weight_name_map:
# Get the refitting mapping
trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)
with unset_fake_temporarily():
refitted = set()
torch_device = get_model_device(new_gm)
refitter = trt.Refitter(old_engine, TRT_LOGGER)
weight_list = refitter.get_all_weights()

if weight_name_map:
# Get the refitting mapping
trt_wt_location = (
trt.TensorLocation.DEVICE
if torch_device.type == "cuda"
else trt.TensorLocation.HOST
)

constant_mapping: dict[str, Any] = weight_name_map.pop(
"constant_mapping", {}
) # type: ignore
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict()
)
constant_mapping_with_type = {}

for constant_name, val in constant_mapping.items():
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
constant_mapping_with_type[constant_name] = (
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
trt_dtype,
constant_mapping: dict[str, Any] = weight_name_map.pop(
"constant_mapping", {}
) # type: ignore
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict()
)
constant_mapping_with_type = {}

for constant_name, val in constant_mapping.items():
np_weight_type = val.dtype
val_tensor = torch.from_numpy(val).cuda()
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
constant_mapping_with_type[constant_name] = (
val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
trt_dtype,
)

mapping.update(constant_mapping_with_type)
mapping.update(constant_mapping_with_type)

for layer_name in weight_list:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Use Numpy to create weights
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"
for layer_name in weight_list:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Use torch tensors to create weights
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"

else:
mapping = construct_refit_mapping(new_gm, input_list, settings)
trt_wt_location = trt.TensorLocation.HOST
for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight, datatype = mapping[layer_name]
trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)

if len(refitted) != len(weight_list):
logger.warning("Not all weights have been refitted!!!")

if not refitter.refit_cuda_engine():
logger.error("Error: failed to refit new weights.")
raise AssertionError("Refitting failed.")
else:
mapping = construct_refit_mapping(new_gm, input_list, settings)
trt_wt_location = trt.TensorLocation.HOST
for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight, datatype = mapping[layer_name]
trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)

if len(refitted) != len(weight_list):
logger.warning("Not all weights have been refitted!!!")

if not refitter.refit_cuda_engine():
logger.error("Error: failed to refit new weights.")
raise AssertionError("Refitting failed.")


def refit_module_weights(
40 changes: 20 additions & 20 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -21,6 +21,7 @@
import tensorrt as trt
import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._python_dispatch import _disable_current_modes
@@ -41,6 +42,7 @@
get_node_io,
get_node_name,
get_trt_tensor,
to_torch,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
from torch_tensorrt.fx.observer import Observer
@@ -408,27 +410,29 @@ def find_weight(
np_map: the map from weight name to np values in INetworkDefinition
state_dict: state of the graph module
"""
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
for sd_w_name, sd_weight in state_dict.items():
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
return sd_w_name
return ""
with unset_fake_temporarily():
network_weight = torch.from_numpy(np_map[weight_name]).to(device)
for sd_w_name, sd_weight in state_dict.items():
if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
del state_dict[sd_w_name]
return sd_w_name
return ""

@staticmethod
def check_weight_equal(
sd_weight: torch.Tensor,
network_weight: Union[torch.Tensor, np.ndarray],
device: torch.device,
) -> Any:
if not isinstance(network_weight, torch.Tensor):
network_weight = torch.from_numpy(network_weight).to(device)
try:
return sd_weight.shape == network_weight.shape and torch.all(
torch.abs(sd_weight - network_weight) < 0.01
)
except Exception:
return torch.all(sd_weight == network_weight)
with unset_fake_temporarily():
if not isinstance(network_weight, torch.Tensor):
network_weight = torch.from_numpy(network_weight).to(device)
try:
return sd_weight.shape == network_weight.shape and torch.all(
torch.abs(sd_weight - network_weight) < 0.01
)
except Exception:
return torch.all(sd_weight == network_weight)
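The weight-matching heuristic above declares two weights equal when their shapes match and every element agrees within an absolute tolerance of 0.01; a small illustration with made-up values:

import numpy as np
import torch

sd_weight = torch.tensor([1.000, 2.000])
network_weight = torch.from_numpy(np.array([1.005, 1.995], dtype=np.float32))
equal = sd_weight.shape == network_weight.shape and bool(
    torch.all(torch.abs(sd_weight - network_weight) < 0.01)
)  # True: both elements differ by 0.005, inside the tolerance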

def _save_weight_mapping(self) -> None:
"""
@@ -887,19 +891,15 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
return converter(self.ctx, target, args, kwargs, self._cur_node_name)

def get_attr(self, target: str, args: Any, kwargs: Any) -> torch.Tensor:
with _disable_current_modes():
from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy

with _disable_current_modes(), unset_fake_temporarily():
frozen_attr = self.fetch_attr(target)

if isinstance(frozen_attr, torch.nn.Parameter):
constant_tensor = frozen_attr.data
else:
constant_tensor = frozen_attr

network_constant = to_numpy(constant_tensor)

return network_constant
return to_torch(constant_tensor)
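The practical effect of returning to_torch(...) here instead of to_numpy(...): graph constants keep their original dtype. A hypothetical illustration (the parameter is invented for the example):

import torch
from torch_tensorrt.dynamo.conversion.converter_utils import to_torch

param = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
constant = param.data                               # what fetch_attr() would hand back
assert to_torch(constant).dtype == torch.bfloat16   # BF16 survives; numpy would force FP32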

def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
assert isinstance(target, str)
103 changes: 92 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -9,6 +9,7 @@
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.fx.node import Argument, Target
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt import _enums
@@ -340,17 +341,47 @@ def create_constant(
Returns:
A TensorRT ITensor that represents the given value.
"""
shape = (1,)
# Rank 0 constant is required in IFillLayer inputs.
if min_rank == 0:
shape = trt.Dims()
numpy_value = to_numpy(value, dtype)
constant = ctx.net.add_constant(
shape if isinstance(value, (int, float, bool)) else value.shape,
numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
)
constant.name = name
return constant.get_output(0)
with unset_fake_temporarily():

torch_value = to_torch(value, dtype)
if torch_value.dtype == torch.float64:
raise ValueError(
"TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model."
)
# Rank 0 constant is required in IFillLayer inputs.
if min_rank == 0 and isinstance(value, (int, float, bool)):
shape = trt.Dims()
elif list(torch_value.shape) == []:
shape = trt.Dims()
else:
shape = list(torch_value.shape)

if torch_value is not None:
if torch_value.dtype == torch.bfloat16:
torch_value_fp32 = torch_value.to(torch.float32)
numpy_value = torch_value_fp32.numpy()
else:
numpy_value = torch_value.numpy()

constant = ctx.net.add_constant(
shape,
numpy_value,
)
constant.name = name

if torch_value.dtype == torch.bfloat16:
return cast_trt_tensor(
ctx,
constant.get_output(0),
trt.DataType.BF16,
name + "_bf16_cast",
)

return constant.get_output(0)
else:
raise ValueError(
f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
)
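The BF16 branch above stages the weight as an FP32 numpy buffer (a form add_constant accepts) and then casts the network tensor back to BF16 in-graph. A rough sketch of that pattern, assuming `network` is an existing trt.INetworkDefinition (builder setup omitted):

import tensorrt as trt
import torch

w = torch.randn(4, dtype=torch.bfloat16)
staged = w.to(torch.float32).numpy()                  # numpy-representable staging copy
const = network.add_constant(tuple(w.shape), staged)  # FP32 constant layer
cast = network.add_cast(const.get_output(0), trt.DataType.BF16)  # back to BF16 in-graph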


def get_trt_tensor(
@@ -564,6 +595,9 @@ def to_numpy(
value = value.dequantize()
elif value.dtype == torch.bfloat16:
# TODO: Remove when numpy has a BF16 type
_LOGGER.warning(
"Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
)
value = value.to(torch.float)

output = value.cpu().detach().contiguous().numpy()
@@ -589,6 +623,53 @@
)


def to_torch(
value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
) -> Optional[torch.Tensor]:
"""
Convert a PyTorch tensor, Numpy array, or scalar to a PyTorch tensor on CPU
Args:
value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
A PyTorch tensor, Numpy array, int, float, or bool
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
If a dtype is given, we will convert the type of the given `value` to this dtype.
Returns:
A PyTorch tensor or None, if the input was None.
"""

cpu_device = torch.device("cpu")
torch_dtype = (
_enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
)

with unset_fake_temporarily():
if value is None:
return None

elif isinstance(value, torch.Tensor):
output = value.to(cpu_device).contiguous()

elif isinstance(value, np.ndarray):
output = torch.from_numpy(value).to(cpu_device).contiguous()

elif isinstance(value, bool):
# Check bool before int: bool is a subclass of int in Python, so the
# int branch would otherwise capture True/False as int32.
output = torch.tensor([value], device=cpu_device, dtype=torch.bool)

elif isinstance(value, int):
output = torch.tensor([value], device=cpu_device, dtype=torch.int32)

elif isinstance(value, float):
output = torch.tensor([value], device=cpu_device, dtype=torch.float32)

else:
raise AssertionError(
f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
)

return output.to(torch_dtype) if torch_dtype else output
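A usage sketch for the new to_torch() helper (expected dtypes assume the bool-before-int ordering noted above):

import numpy as np
import torch
from torch_tensorrt.dynamo.conversion.converter_utils import to_torch

assert to_torch(None) is None
assert to_torch(3).dtype == torch.int32                    # Python int -> int32
assert to_torch(2.5).dtype == torch.float32                # Python float -> float32
assert to_torch(True).dtype == torch.bool                  # bool handled before int
assert to_torch(np.ones((2, 2))).shape == (2, 2)           # numpy -> contiguous CPU tensor
assert to_torch(torch.tensor([1.0]), dtype=torch.half).dtype == torch.float16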


def flatten_dims(
input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
start_dim: int,
14 changes: 6 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
@@ -13,7 +13,7 @@
cast_trt_tensor,
extend_attr_to_tuple,
get_trt_tensor,
to_numpy,
to_torch,
)
from torch_tensorrt.fx.converters.converter_utils import (
get_dyn_range,
@@ -45,7 +45,6 @@ def convNd(
assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."

num_dims = len(input.shape) - 2

if is_conv1d:
# Apply an unsqueeze operation to transform the conv1d problem into conv2d
input = impl.unsqueeze.unsqueeze(
@@ -54,8 +53,8 @@

# Process bias terms
if isinstance(bias, (torch.Tensor, np.ndarray)):
# Transform the bias constant into a Numpy array
bias = to_numpy(bias, dtype=input.dtype)
bias = to_torch(bias, dtype=input.dtype)
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

elif isinstance(bias, TRTTensor):
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
@@ -74,12 +73,11 @@
ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1
)
elif isinstance(weight, (torch.Tensor, np.ndarray)):
# Transform the weight constant into a Numpy array
weight = to_numpy(weight, dtype=input.dtype)

weight = to_torch(weight, dtype=input.dtype)
# Append new dimension (unsqueeze) if the convolution is 1d
if is_conv1d:
weight = np.expand_dims(weight, -1)
weight = torch.unsqueeze(weight, -1)
weight = get_trt_tensor(ctx, weight, f"{name}_weight")

else:
raise RuntimeError(
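And the conv change in practice: the conv1d weight is now unsqueezed in torch rather than numpy, so low-precision dtypes survive untouched. A sketch with illustrative shapes:

import torch

weight = torch.randn(8, 4, 3, dtype=torch.bfloat16)  # (out_ch, in_ch, kernel) for conv1d
weight = torch.unsqueeze(weight, -1)                  # (8, 4, 3, 1): the conv2d trick
# np.expand_dims would have required a numpy round-trip, which BF16 cannot make.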