Commit 97273d4

Browse files
committed
fix: Change the translational layer from numpy to torch during conversion to handle additional data types (#3445)
1 parent 1167721 commit 97273d4

File tree

8 files changed, +325 / -146 lines changed
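
The commit replaces the numpy translation layer used when materializing weights and constants with a torch-based one (the new to_torch helper in converter_utils.py below), because numpy cannot represent several dtypes that torch supports. A minimal sketch of the failure this addresses, assuming only stock torch and numpy (illustrative, not part of the commit):

import torch

w = torch.randn(4, dtype=torch.bfloat16)  # e.g. a bfloat16 module weight
# numpy has no bfloat16 dtype, so w.numpy() raises
# "TypeError: Got unsupported ScalarType BFloat16";
# the old numpy path therefore had to upcast and lose the original dtype:
w_fp32 = w.to(torch.float32).numpy()
# keeping the value as a torch tensor preserves the dtype end to end:
w_torch = w.cpu().contiguous()  # still torch.bfloat16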

py/torch_tensorrt/dynamo/_refit.py

+62-60
@@ -9,6 +9,7 @@
 import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
@@ -144,71 +145,72 @@ def _refit_single_trt_engine_with_gm(
     Refit a TensorRT Engine in place
     """
 
-    refitted = set()
-    torch_device = get_model_device(new_gm)
-    refitter = trt.Refitter(old_engine, TRT_LOGGER)
-    weight_list = refitter.get_all_weights()
-
-    if weight_name_map:
-        # Get the refitting mapping
-        trt_wt_location = (
-            trt.TensorLocation.DEVICE
-            if torch_device.type == "cuda"
-            else trt.TensorLocation.HOST
-        )
-
-        constant_mapping: dict[str, Any] = weight_name_map.pop(
-            "constant_mapping", {}
-        )  # type: ignore
-        mapping = construct_refit_mapping_from_weight_name_map(
-            weight_name_map, new_gm.state_dict()
-        )
-        constant_mapping_with_type = {}
-
-        for constant_name, val in constant_mapping.items():
-            np_weight_type = val.dtype
-            val_tensor = torch.from_numpy(val).cuda()
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-            constant_mapping_with_type[constant_name] = (
-                val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
-                trt_dtype,
-            )
-
-        mapping.update(constant_mapping_with_type)
-
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                logger.warning(f"{layer_name} is not found in weight mapping.")
-                continue
-            # Use Numpy to create weights
-            weight, weight_dtype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(
-                weight_dtype, weight.data_ptr(), torch.numel(weight)
-            )
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        assert (
-            len(refitter.get_missing_weights()) == 0
-        ), "Fast refitting failed due to incomplete mapping"
-
-    else:
-        mapping = construct_refit_mapping(new_gm, input_list, settings)
-        trt_wt_location = trt.TensorLocation.HOST
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-            refitted.add(layer_name)
-
-    if len(refitted) != len(weight_list):
-        logger.warning("Not all weights have been refitted!!!")
-
-    if not refitter.refit_cuda_engine():
-        logger.error("Error: failed to refit new weights.")
-        raise AssertionError("Refitting failed.")
+    with unset_fake_temporarily():
+        refitted = set()
+        torch_device = get_model_device(new_gm)
+        refitter = trt.Refitter(old_engine, TRT_LOGGER)
+        weight_list = refitter.get_all_weights()
+
+        if weight_name_map:
+            # Get the refitting mapping
+            trt_wt_location = (
+                trt.TensorLocation.DEVICE
+                if torch_device.type == "cuda"
+                else trt.TensorLocation.HOST
+            )
+
+            constant_mapping: dict[str, Any] = weight_name_map.pop(
+                "constant_mapping", {}
+            )  # type: ignore
+            mapping = construct_refit_mapping_from_weight_name_map(
+                weight_name_map, new_gm.state_dict()
+            )
+            constant_mapping_with_type = {}
+
+            for constant_name, val in constant_mapping.items():
+                np_weight_type = val.dtype
+                val_tensor = torch.from_numpy(val).cuda()
+                trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+                torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+                constant_mapping_with_type[constant_name] = (
+                    val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
+                    trt_dtype,
+                )
+
+            mapping.update(constant_mapping_with_type)
+
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    logger.warning(f"{layer_name} is not found in weight mapping.")
+                    continue
+                # Use Numpy to create weights
+                weight, weight_dtype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(
+                    weight_dtype, weight.data_ptr(), torch.numel(weight)
+                )
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            assert (
+                len(refitter.get_missing_weights()) == 0
+            ), "Fast refitting failed due to incomplete mapping"
+
+        else:
+            mapping = construct_refit_mapping(new_gm, input_list, settings)
+            trt_wt_location = trt.TensorLocation.HOST
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    raise AssertionError(f"{layer_name} is not found in weight mapping")
+                # Use Numpy to create weights
+                weight, datatype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+                refitted.add(layer_name)
+
+        if len(refitted) != len(weight_list):
+            logger.warning("Not all weights have been refitted!!!")
+
+        if not refitter.refit_cuda_engine():
+            logger.error("Error: failed to refit new weights.")
+            raise AssertionError("Refitting failed.")
 
 
 def refit_module_weights(
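
The refit path above now runs inside unset_fake_temporarily, so real tensors can be created even if the surrounding export/compile flow has a fake-tensor mode active. A minimal sketch of what the context manager does, with an explicitly constructed FakeTensorMode standing in for the tracing context (illustrative only, not part of the commit):

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones(3)          # FakeTensor: carries shape/dtype, no real storage
    assert isinstance(fake, FakeTensor)
    with unset_fake_temporarily():
        real = torch.ones(3)      # real tensor again; safe to hand to the TRT refitter
        assert not isinstance(real, FakeTensor)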

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+20-20
@@ -21,6 +21,7 @@
 import tensorrt as trt
 import torch
 import torch.fx
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
@@ -41,6 +42,7 @@
     get_node_io,
     get_node_name,
     get_trt_tensor,
+    to_torch,
 )
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
 from torch_tensorrt.fx.observer import Observer
@@ -408,27 +410,29 @@ def find_weight(
         np_map: the map from weight name to np values in INetworkDefinition
         state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
-        for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
-                del state_dict[sd_w_name]
-                return sd_w_name
-        return ""
+        with unset_fake_temporarily():
+            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            for sd_w_name, sd_weight in state_dict.items():
+                if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
+                    del state_dict[sd_w_name]
+                    return sd_w_name
+            return ""
 
     @staticmethod
     def check_weight_equal(
         sd_weight: torch.tensor,
         network_weight: Union[torch.Tensor, np.ndarray],
         device: torch.device,
     ) -> Any:
-        if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).to(device)
-        try:
-            return sd_weight.shape == network_weight.shape and torch.all(
-                torch.abs(sd_weight - network_weight) < 0.01
-            )
-        except Exception:
-            return torch.all(sd_weight == network_weight)
+        with unset_fake_temporarily():
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).to(device)
+            try:
+                return sd_weight.shape == network_weight.shape and torch.all(
+                    torch.abs(sd_weight - network_weight) < 0.01
+                )
+            except Exception:
+                return torch.all(sd_weight == network_weight)
 
     def _save_weight_mapping(self) -> None:
         """
@@ -887,19 +891,15 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
     def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
-        with _disable_current_modes():
-            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
-
+        with _disable_current_modes(), unset_fake_temporarily():
             frozen_attr = self.fetch_attr(target)
 
             if isinstance(frozen_attr, torch.nn.Parameter):
                 constant_tensor = frozen_attr.data
             else:
                 constant_tensor = frozen_attr
 
-            network_constant = to_numpy(constant_tensor)
-
-            return network_constant
+            return to_torch(constant_tensor)
 
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+92-11
@@ -9,6 +9,7 @@
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
@@ -340,17 +341,47 @@ def create_constant(
     Returns:
         A TensorRT ITensor that represents the given value.
     """
-    shape = (1,)
-    # Rank 0 constant is required in IFillLayer inputs.
-    if min_rank == 0:
-        shape = trt.Dims()
-    numpy_value = to_numpy(value, dtype)
-    constant = ctx.net.add_constant(
-        shape if isinstance(value, (int, float, bool)) else value.shape,
-        numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
-    )
-    constant.name = name
-    return constant.get_output(0)
+    with unset_fake_temporarily():
+
+        torch_value = to_torch(value, dtype)
+        if torch_value.dtype == torch.float64:
+            raise ValueError(
+                "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model."
+            )
+        # Rank 0 constant is required in IFillLayer inputs.
+        if min_rank == 0 and isinstance(value, (int, float, bool)):
+            shape = trt.Dims()
+        elif list(torch_value.shape) == []:
+            shape = trt.Dims()
+        else:
+            shape = list(torch_value.shape)
+
+        if torch_value is not None:
+            if torch_value.dtype == torch.bfloat16:
+                torch_value_fp32 = torch_value.to(torch.float32)
+                numpy_value = torch_value_fp32.numpy()
+            else:
+                numpy_value = torch_value.numpy()
+
+            constant = ctx.net.add_constant(
+                shape,
+                numpy_value,
+            )
+            constant.name = name
+
+            if torch_value.dtype == torch.bfloat16:
+                return cast_trt_tensor(
+                    ctx,
+                    constant.get_output(0),
+                    trt.DataType.BF16,
+                    name + "_bf16_cast",
+                )
+
+            return constant.get_output(0)
+        else:
+            raise ValueError(
+                f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
+            )
 
 
 def get_trt_tensor(
@@ -564,6 +595,9 @@ def to_numpy(
             value = value.dequantize()
         elif value.dtype == torch.bfloat16:
             # TODO: Remove when numpy has a BF16 type
+            _LOGGER.warning(
+                "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
+            )
             value = value.to(torch.float)
 
         output = value.cpu().detach().contiguous().numpy()
@@ -589,6 +623,53 @@
         )
 
 
+def to_torch(
+    value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
+) -> Optional[torch.Tensor]:
+    """
+    Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
+    Args:
+        value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
+            A PyTorch tensor, Numpy array, int, float, or bool
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If a dtype is given, we will convert the type of the given `value` to this dtype.
+    Returns:
+        A PyTorch tensor or None, if the input was None.
+    """
+
+    cpu_device = torch.device("cpu")
+    torch_dtype = (
+        _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
+    )
+
+    with unset_fake_temporarily():
+        if value is None:
+            return None
+
+        elif isinstance(value, torch.Tensor):
+            output = value.to(cpu_device).contiguous()
+
+        elif isinstance(value, np.ndarray):
+            output = torch.from_numpy(value).to(cpu_device).contiguous()
+
+        elif isinstance(value, int):
+            output = torch.tensor([value], device=cpu_device, dtype=torch.int32)
+
+        elif isinstance(value, float):
+            output = torch.tensor([value], device=cpu_device, dtype=torch.float32)
+
+        elif isinstance(value, bool):
+            output = torch.tensor([value], device=cpu_device, dtype=torch.bool)
+
+        else:
+            raise AssertionError(
+                f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
+            )
+
+        return output.to(torch_dtype) if torch_dtype else output
+
+
 def flatten_dims(
     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     start_dim: int,
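
A short usage sketch of the to_torch helper added above; the calls mirror the branches in its body, and the resulting dtypes follow directly from that code (illustrative only):

import numpy as np
import torch
from torch_tensorrt.dynamo.conversion.converter_utils import to_torch

print(to_torch(None))                               # None
print(to_torch(3).dtype)                            # torch.int32
print(to_torch(1.5).dtype)                          # torch.float32
print(to_torch(np.zeros((2, 2))).dtype)             # torch.float64, taken from the numpy array
print(to_torch(torch.randn(2, dtype=torch.bfloat16)).dtype)  # torch.bfloat16, preserved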

py/torch_tensorrt/dynamo/conversion/impl/conv.py

+6-8
@@ -13,7 +13,7 @@
     cast_trt_tensor,
     extend_attr_to_tuple,
     get_trt_tensor,
-    to_numpy,
+    to_torch,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_dyn_range,
@@ -45,7 +45,6 @@ def convNd(
     assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution."
 
     num_dims = len(input.shape) - 2
-
     if is_conv1d:
         # Apply an unsqueeze operation to transform the conv1d problem into conv2d
         input = impl.unsqueeze.unsqueeze(
@@ -54,8 +53,8 @@ def convNd(
 
     # Process bias terms
     if isinstance(bias, (torch.Tensor, np.ndarray)):
-        # Transform the bias constant into a Numpy array
-        bias = to_numpy(bias, dtype=input.dtype)
+        bias = to_torch(bias, dtype=input.dtype)
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
 
     elif isinstance(bias, TRTTensor):
         bias = get_trt_tensor(ctx, bias, f"{name}_bias")
@@ -74,12 +73,11 @@ def convNd(
                 ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1
             )
     elif isinstance(weight, (torch.Tensor, np.ndarray)):
-        # Transform the weight constant into a Numpy array
-        weight = to_numpy(weight, dtype=input.dtype)
-
+        weight = to_torch(weight, dtype=input.dtype)
         # Append new dimension (unsqueeze) if the convolution is 1d
         if is_conv1d:
-            weight = np.expand_dims(weight, -1)
+            weight = torch.unsqueeze(weight, -1)
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
 
     else:
         raise RuntimeError(
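
In convNd above, the only behavioral change is the translation layer: torch and numpy weights and biases now pass through to_torch and get_trt_tensor, and the conv1d weight reshape moves from np.expand_dims to torch.unsqueeze. A tiny sanity sketch of that equivalent reshape (illustrative only, not part of the commit):

import numpy as np
import torch

w_np = np.ones((16, 8, 3), dtype=np.float32)
w_t = torch.from_numpy(w_np)

# Same trailing-dimension trick either way; the torch variant also works for
# dtypes numpy cannot represent, such as bfloat16.
assert np.expand_dims(w_np, -1).shape == tuple(torch.unsqueeze(w_t, -1).shape)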
