diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 96fc6daad2..c128e9cc82 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch from torch.export import ExportedProgram +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import partitioning @@ -144,71 +145,72 @@ def _refit_single_trt_engine_with_gm( Refit a TensorRT Engine in place """ - refitted = set() - torch_device = get_model_device(new_gm) - refitter = trt.Refitter(old_engine, TRT_LOGGER) - weight_list = refitter.get_all_weights() - - if weight_name_map: - # Get the refitting mapping - trt_wt_location = ( - trt.TensorLocation.DEVICE - if torch_device.type == "cuda" - else trt.TensorLocation.HOST - ) + with unset_fake_temporarily(): + refitted = set() + torch_device = get_model_device(new_gm) + refitter = trt.Refitter(old_engine, TRT_LOGGER) + weight_list = refitter.get_all_weights() + + if weight_name_map: + # Get the refitting mapping + trt_wt_location = ( + trt.TensorLocation.DEVICE + if torch_device.type == "cuda" + else trt.TensorLocation.HOST + ) - constant_mapping: dict[str, Any] = weight_name_map.pop( - "constant_mapping", {} - ) # type: ignore - mapping = construct_refit_mapping_from_weight_name_map( - weight_name_map, new_gm.state_dict() - ) - constant_mapping_with_type = {} - - for constant_name, val in constant_mapping.items(): - np_weight_type = val.dtype - val_tensor = torch.from_numpy(val).cuda() - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) - constant_mapping_with_type[constant_name] = ( - val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), - trt_dtype, + constant_mapping: dict[str, Any] = weight_name_map.pop( + "constant_mapping", {} + ) # type: ignore + mapping = construct_refit_mapping_from_weight_name_map( + weight_name_map, new_gm.state_dict() ) + constant_mapping_with_type = {} + + for constant_name, val in constant_mapping.items(): + np_weight_type = val.dtype + val_tensor = torch.from_numpy(val).cuda() + trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) + torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + constant_mapping_with_type[constant_name] = ( + val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), + trt_dtype, + ) - mapping.update(constant_mapping_with_type) + mapping.update(constant_mapping_with_type) - for layer_name in weight_list: - if layer_name not in mapping: - logger.warning(f"{layer_name} is not found in weight mapping.") - continue - # Use Numpy to create weights - weight, weight_dtype = mapping[layer_name] - trt_wt_tensor = trt.Weights( - weight_dtype, weight.data_ptr(), torch.numel(weight) - ) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - assert ( - len(refitter.get_missing_weights()) == 0 - ), "Fast refitting failed due to incomplete mapping" + for layer_name in weight_list: + if layer_name not in mapping: + logger.warning(f"{layer_name} is not found in weight mapping.") + continue + # Use Numpy to create weights + weight, weight_dtype = mapping[layer_name] + trt_wt_tensor = trt.Weights( + weight_dtype, weight.data_ptr(), torch.numel(weight) + ) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + assert ( + len(refitter.get_missing_weights()) == 0 + ), "Fast refitting failed 
due to incomplete mapping" - else: - mapping = construct_refit_mapping(new_gm, input_list, settings) - trt_wt_location = trt.TensorLocation.HOST - for layer_name in weight_list: - if layer_name not in mapping: - raise AssertionError(f"{layer_name} is not found in weight mapping") - # Use Numpy to create weights - weight, datatype = mapping[layer_name] - trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) - refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) - refitted.add(layer_name) - - if len(refitted) != len(weight_list): - logger.warning("Not all weights have been refitted!!!") - - if not refitter.refit_cuda_engine(): - logger.error("Error: failed to refit new weights.") - raise AssertionError("Refitting failed.") + else: + mapping = construct_refit_mapping(new_gm, input_list, settings) + trt_wt_location = trt.TensorLocation.HOST + for layer_name in weight_list: + if layer_name not in mapping: + raise AssertionError(f"{layer_name} is not found in weight mapping") + # Use Numpy to create weights + weight, datatype = mapping[layer_name] + trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) + refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) + refitted.add(layer_name) + + if len(refitted) != len(weight_list): + logger.warning("Not all weights have been refitted!!!") + + if not refitter.refit_cuda_engine(): + logger.error("Error: failed to refit new weights.") + raise AssertionError("Refitting failed.") def refit_module_weights( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 248e06bc3c..17f2fccbff 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -21,6 +21,7 @@ import tensorrt as trt import torch import torch.fx +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata from torch.utils._python_dispatch import _disable_current_modes @@ -41,6 +42,7 @@ get_node_io, get_node_name, get_trt_tensor, + to_torch, ) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device from torch_tensorrt.fx.observer import Observer @@ -408,12 +410,13 @@ def find_weight( np_map: the map from weight name to np values in INetworkDefinition state_dict: state of the graph module """ - network_weight = torch.from_numpy(np_map[weight_name]).to(device) - for sd_w_name, sd_weight in state_dict.items(): - if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): - del state_dict[sd_w_name] - return sd_w_name - return "" + with unset_fake_temporarily(): + network_weight = torch.from_numpy(np_map[weight_name]).to(device) + for sd_w_name, sd_weight in state_dict.items(): + if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device): + del state_dict[sd_w_name] + return sd_w_name + return "" @staticmethod def check_weight_equal( @@ -421,14 +424,15 @@ def check_weight_equal( network_weight: Union[torch.Tensor, np.ndarray], device: torch.device, ) -> Any: - if not isinstance(network_weight, torch.Tensor): - network_weight = torch.from_numpy(network_weight).to(device) - try: - return sd_weight.shape == network_weight.shape and torch.all( - torch.abs(sd_weight - network_weight) < 0.01 - ) - except Exception: - return torch.all(sd_weight == network_weight) + with unset_fake_temporarily(): + if not isinstance(network_weight, 
torch.Tensor): + network_weight = torch.from_numpy(network_weight).to(device) + try: + return sd_weight.shape == network_weight.shape and torch.all( + torch.abs(sd_weight - network_weight) < 0.01 + ) + except Exception: + return torch.all(sd_weight == network_weight) def _save_weight_mapping(self) -> None: """ @@ -887,9 +891,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: return converter(self.ctx, target, args, kwargs, self._cur_node_name) def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: - with _disable_current_modes(): - from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy - + with _disable_current_modes(), unset_fake_temporarily(): frozen_attr = self.fetch_attr(target) if isinstance(frozen_attr, torch.nn.Parameter): @@ -897,9 +899,7 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray: else: constant_tensor = frozen_attr - network_constant = to_numpy(constant_tensor) - - return network_constant + return to_torch(constant_tensor) def call_method(self, target: str, args: Any, kwargs: Any) -> Any: assert isinstance(target, str) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 62526080c4..bcb8495c67 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Argument, Target from torch.fx.passes.shape_prop import TensorMetadata from torch_tensorrt import _enums @@ -340,17 +341,47 @@ def create_constant( Returns: A TensorRT ITensor that represents the given value. """ - shape = (1,) - # Rank 0 constant is required in IFillLayer inputs. - if min_rank == 0: - shape = trt.Dims() - numpy_value = to_numpy(value, dtype) - constant = ctx.net.add_constant( - shape if isinstance(value, (int, float, bool)) else value.shape, - numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value, - ) - constant.name = name - return constant.get_output(0) + with unset_fake_temporarily(): + + torch_value = to_torch(value, dtype) + if torch_value.dtype == torch.float64: + raise ValueError( + "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model." + ) + # Rank 0 constant is required in IFillLayer inputs. + if min_rank == 0 and isinstance(value, (int, float, bool)): + shape = trt.Dims() + elif list(torch_value.shape) == []: + shape = trt.Dims() + else: + shape = list(torch_value.shape) + + if torch_value is not None: + if torch_value.dtype == torch.bfloat16: + torch_value_fp32 = torch_value.to(torch.float32) + numpy_value = torch_value_fp32.numpy() + else: + numpy_value = torch_value.numpy() + + constant = ctx.net.add_constant( + shape, + numpy_value, + ) + constant.name = name + + if torch_value.dtype == torch.bfloat16: + return cast_trt_tensor( + ctx, + constant.get_output(0), + trt.DataType.BF16, + name + "_bf16_cast", + ) + + return constant.get_output(0) + else: + raise ValueError( + f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None." 
+ ) def get_trt_tensor( @@ -564,6 +595,9 @@ def to_numpy( value = value.dequantize() elif value.dtype == torch.bfloat16: # TODO: Remove when numpy has a BF16 type + _LOGGER.warning( + "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation", + ) value = value.to(torch.float) output = value.cpu().detach().contiguous().numpy() @@ -589,6 +623,53 @@ def to_numpy( ) +def to_torch( + value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]], + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None, +) -> Optional[torch.Tensor]: + """ + Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU + Args: + value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]): + A PyTorch tensor, Numpy array, int, float, or bool + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If a dtype is given, we will convert the type of the given `value` to this dtype. + Returns: + A PyTorch tensor or None, if the input was None. + """ + + cpu_device = torch.device("cpu") + torch_dtype = ( + _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None + ) + + with unset_fake_temporarily(): + if value is None: + return None + + elif isinstance(value, torch.Tensor): + output = value.to(cpu_device).contiguous() + + elif isinstance(value, np.ndarray): + output = torch.from_numpy(value).to(cpu_device).contiguous() + + elif isinstance(value, int): + output = torch.tensor([value], device=cpu_device, dtype=torch.int32) + + elif isinstance(value, float): + output = torch.tensor([value], device=cpu_device, dtype=torch.float32) + + elif isinstance(value, bool): + output = torch.tensor([value], device=cpu_device, dtype=torch.bool) + + else: + raise AssertionError( + f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}" + ) + + return output.to(torch_dtype) if torch_dtype else output + + def flatten_dims( input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]], start_dim: int, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index 25419d7f60..f27fb13e97 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -13,7 +13,7 @@ cast_trt_tensor, extend_attr_to_tuple, get_trt_tensor, - to_numpy, + to_torch, ) from torch_tensorrt.fx.converters.converter_utils import ( get_dyn_range, @@ -45,7 +45,6 @@ def convNd( assert input.shape[1] != -1, "Channel dim can't be dynamic for convolution." 
num_dims = len(input.shape) - 2 - if is_conv1d: # Apply an unsqueeze operation to transform the conv1d problem into conv2d input = impl.unsqueeze.unsqueeze( @@ -54,8 +53,8 @@ def convNd( # Process bias terms if isinstance(bias, (torch.Tensor, np.ndarray)): - # Transform the bias constant into a Numpy array - bias = to_numpy(bias, dtype=input.dtype) + bias = to_torch(bias, dtype=input.dtype) + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif isinstance(bias, TRTTensor): bias = get_trt_tensor(ctx, bias, f"{name}_bias") @@ -74,12 +73,11 @@ def convNd( ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1 ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - # Transform the weight constant into a Numpy array - weight = to_numpy(weight, dtype=input.dtype) - + weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the convolution is 1d if is_conv1d: - weight = np.expand_dims(weight, -1) + weight = torch.unsqueeze(weight, -1) + weight = get_trt_tensor(ctx, weight, f"{name}_weight") else: raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py index d19a92e646..629cecf5db 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py @@ -6,13 +6,12 @@ import tensorrt as trt import torch from torch.fx.node import Target - from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( extend_attr_to_tuple, get_trt_tensor, - to_numpy, + to_torch, ) from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, @@ -53,7 +52,8 @@ def deconvNd( # Process bias terms if isinstance(bias, (torch.Tensor, np.ndarray)): # Transform the bias constant into a Numpy array - bias = to_numpy(bias) + bias = to_torch(bias, dtype=input.dtype) + bias = get_trt_tensor(ctx, bias, f"{name}_bias") elif isinstance(bias, TRTTensor): bias = get_trt_tensor(ctx, bias, f"{name}_bias") @@ -73,12 +73,12 @@ def deconvNd( ) elif isinstance(weight, (torch.Tensor, np.ndarray)): - # Transform the weight constant into a Numpy array - weight = to_numpy(weight) - + weight = to_torch(weight, dtype=input.dtype) # Append new dimension (unsqueeze) if the deconvolution is 1d if is_deconv1d: - weight = np.expand_dims(weight, axis=-1) + weight = torch.unsqueeze(weight, -1) + + weight = get_trt_tensor(ctx, weight, f"{name}_weight") else: raise RuntimeError( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py index b97840cd09..e472ed3092 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/quantize.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/quantize.py @@ -1,11 +1,13 @@ -from typing import Optional +from typing import Optional, Union import numpy as np import tensorrt as trt +import torch +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor, to_torch from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -16,7 +18,7 @@ def quantize( source_ir: 
Optional[SourceIR], name: str, input_tensor: TRTTensor, - amax: np.ndarray, + amax: Union[np.ndarray, torch.Tensor], num_bits: int, exponent_bits: int, ) -> TRTTensor: @@ -24,40 +26,44 @@ def quantize( Adds quantize and dequantize ops (QDQ) which quantize to INT8 or FP8 based on the output_type set and dequantizes them back. """ - if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( - trt.float32, - trt.float16, - ): - raise ValueError( - f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16" - ) - if num_bits != 8 or exponent_bits not in (0, 4): - raise ValueError( - f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" - ) - if num_bits == 8 and exponent_bits == 0: - max_bound = 127 - elif num_bits == 8 and exponent_bits == 4: - max_bound = 448 - scale = np.divide(amax, max_bound) - scale = get_trt_tensor(ctx, scale, name + "_scale") - # Add Q node - quantize_layer = ctx.net.add_quantize(input_tensor, scale) - if num_bits == 8 and exponent_bits == 0: - quantize_layer.set_output_type(0, trt.DataType.INT8) - elif num_bits == 8 and exponent_bits == 4: - quantize_layer.set_output_type(0, trt.DataType.FP8) - set_layer_name(quantize_layer, target, name + "_quantize", source_ir) - q_output = quantize_layer.get_output(0) - # Add DQ node - dequantize_layer = ctx.net.add_dequantize(q_output, scale) - set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) - if num_bits == 8 and exponent_bits == 0: - dequantize_layer.precision = trt.DataType.INT8 - elif num_bits == 8 and exponent_bits == 4: - # Set DQ layer precision to FP8 - dequantize_layer.precision = trt.DataType.FP8 - dq_output = dequantize_layer.get_output(0) + with unset_fake_temporarily(): + if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in ( + trt.float32, + trt.float16, + ): + raise ValueError( + f"quantize converter received an input of {input_tensor.dtype} type. 
Supported types: float32 | float16" + ) + if num_bits != 8 or exponent_bits not in (0, 4): + raise ValueError( + f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}" + ) + if num_bits == 8 and exponent_bits == 0: + max_bound = 127 + elif num_bits == 8 and exponent_bits == 4: + max_bound = 448 - return dq_output + amax = to_torch(amax, None) + scale = torch.divide(amax, max_bound) + scale = get_trt_tensor(ctx, scale, name + "_scale") + # Add Q node + quantize_layer = ctx.net.add_quantize(input_tensor, scale) + if num_bits == 8 and exponent_bits == 0: + quantize_layer.set_output_type(0, trt.DataType.INT8) + elif num_bits == 8 and exponent_bits == 4: + quantize_layer.set_output_type(0, trt.DataType.FP8) + + set_layer_name(quantize_layer, target, name + "_quantize", source_ir) + q_output = quantize_layer.get_output(0) + # Add DQ node + dequantize_layer = ctx.net.add_dequantize(q_output, scale) + set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir) + if num_bits == 8 and exponent_bits == 0: + dequantize_layer.precision = trt.DataType.INT8 + elif num_bits == 8 and exponent_bits == 4: + # Set DQ layer precision to FP8 + dequantize_layer.precision = trt.DataType.FP8 + dq_output = dequantize_layer.get_output(0) + + return dq_output diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b6f986711a..6314baa5ec 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -182,3 +182,94 @@ def test_resnet18_half(ir): # Clean up model env torch._dynamo.reset() + + +@pytest.mark.unit +def test_bf16_model(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + model = MyModule().eval().cuda().to(torch.bfloat16) + input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.bfloat16) + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.bfloat16, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float32}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"BF16 model TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + +@pytest.mark.unit +def test_bf16_fallback_model(ir): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1, stride=1, bias=True) + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(16, 16, 3, padding=1, stride=1, bias=True) + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + out = self.conv2(out) + return out + + model = MyModule().eval().cuda().to(torch.bfloat16) + input = torch.randn((1, 3, 224, 224)).to("cuda").to(torch.bfloat16) + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.bfloat16, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float32}, + "ir": ir, + "pass_through_build_failures": True, + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_explicit_typing": True, + "torch_executed_ops": {"torch.ops.aten.relu.default"}, + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)) + + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"BF16 fallback model TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 469ed569d1..6f96e259b0 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -249,6 +249,7 @@ def calibrate_loop(model): @unittest.skipIf( platform.system() != "Linux" + or torch.cuda.get_device_capability() < (8, 9) or not importlib.util.find_spec("modelopt") or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"), "modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux", @@ -257,7 +258,6 @@ def calibrate_loop(model): def test_base_int8(ir): import modelopt.torch.quantization as mtq from modelopt.torch.quantization.utils import export_torch_mode - from torch.export._trace import _export class SimpleNetwork(torch.nn.Module): def __init__(self): @@ -285,7 +285,7 @@ def calibrate_loop(model): with torch.no_grad(): with export_torch_mode(): - exp_program = _export(model, (input_tensor,)) + exp_program = torch.export.export(model, (input_tensor,)) trt_model = torchtrt.dynamo.compile( exp_program, inputs=[input_tensor], @@ -294,6 +294,7 @@ def calibrate_loop(model): debug=True, cache_built_engines=False, reuse_cached_engines=False, + truncate_double=True, ) outputs_trt = trt_model(input_tensor) assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)
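
For reference, here is a minimal usage sketch of the `to_torch()` helper added to `converter_utils.py` in this diff. It is not part of the patch; it simply mirrors the branches visible above, and assumes a `torch_tensorrt` build containing this change so that the import path resolves.

```python
import numpy as np
import torch

# Assumes the helper introduced in this diff is importable from the module it was added to.
from torch_tensorrt.dynamo.conversion.converter_utils import to_torch

# None passes through unchanged.
assert to_torch(None) is None

# Python scalars are wrapped in 1-element CPU tensors (int32 / float32 per the diff).
assert to_torch(3).dtype == torch.int32
assert to_torch(1.5).dtype == torch.float32

# Numpy arrays become contiguous CPU tensors; the optional dtype argument coerces the result.
arr = np.arange(4, dtype=np.float16).reshape(2, 2)
t = to_torch(arr, dtype=torch.float32)
assert t.device.type == "cpu" and t.dtype == torch.float32 and t.is_contiguous()
```

The sketch deliberately does not assert on a bare `True`/`False` input: since Python's `bool` is a subclass of `int` and the `int` branch precedes the `bool` branch in the function as written, a bool value would take the int32 path rather than the `torch.bool` one.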