diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py
index 29ab495fec..bee0c3dbf0 100644
--- a/py/torch_tensorrt/_features.py
+++ b/py/torch_tensorrt/_features.py
@@ -37,7 +37,7 @@
 _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
 _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
 _FX_FE_AVAIL = True
-_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
+_REFIT_AVAIL = True
 
 ENABLED_FEATURES = FeatureSet(
     _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index a1eda40b2d..7be7e0f16c 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -62,18 +62,6 @@ def construct_refit_mapping(
     Returns:
         Mapping from weight name in TensorRT to actual weight value in np.ndarray
     """
-    MODULE_MAP = {
-        "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]),
-        "CONVOLUTION": (
-            trt.IConvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "DECONVOLUTION": (
-            trt.IDeconvolutionLayer,
-            [("kernel", "KERNEL"), ("bias", "BIAS")],
-        ),
-        "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]),
-    }
 
     output_dtypes = infer_module_output_dtypes(
         module,
@@ -81,7 +69,6 @@
     )
 
     # Use Interpreter
-    weight_map = {}
     interpreter = TRTInterpreter(
         module,
         inputs,
@@ -90,24 +77,8 @@
         compilation_settings=settings,
     )
     interpreter._construct_trt_network_def()
-    net = interpreter.ctx.net
-    for i in range(net.num_layers):
-        layer = net[i]
-        layer_type: str = layer.type.name
-        if layer_type in MODULE_MAP:
-            # Cast the parent class to child class to access attributes
-            # For example: ILayer does not have ILayer.kernel/ILayer.bias
-            # So we cast it to IConvolutionLayer and access the attributes
-            layer.__class__ = MODULE_MAP[layer_type][0]
-            for weight_type, weight_name in MODULE_MAP[layer_type][1]:
-                weight = layer.__getattribute__(weight_type).copy()
-                weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType)
-                weight_map[f"{layer.name} {weight_name}"] = (
-                    weight,
-                    weight_dtype,
-                )
 
-    return weight_map
+    return interpreter.ctx.mapping
 
 
 @needs_refit
@@ -118,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map(
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
-        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-
         if sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
+            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
@@ -178,8 +148,8 @@ def _refit_single_trt_engine_with_gm(
         for constant_name, val in constant_mapping.items():
             np_weight_type = val.dtype
             val_tensor = torch.from_numpy(val).cuda()
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
             constant_mapping_with_type[constant_name] = (
                 val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
                 trt_dtype,
             )
@@ -208,8 +178,9 @@ def _refit_single_trt_engine_with_gm(
             if layer_name not in mapping:
                 raise AssertionError(f"{layer_name} is not found in weight mapping")
             # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            weight = mapping[layer_name]
+            trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
+            trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
             refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
             refitted.add(layer_name)
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
index 0dbdb2a8f4..fa5eacf7c7 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 
+import numpy as np
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.fx.types import TRTNetwork
 
@@ -19,3 +20,4 @@ class ConversionContext:
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
+    mapping: dict[str, np.array] = field(default_factory=dict)
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index fde07bf1f5..908ceaec41 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -498,19 +498,15 @@ def _save_weight_mapping(self) -> None:
             for k, v in self.module.state_dict().items()
         }
         weight_name_map: dict[str, Any] = {}
-        np_map = {}
-        constant_mapping = {}
+        np_map = self.ctx.mapping
+        constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):
             layer = net[i]
             layer_type: str = layer.type.name
             if layer_type in MODULE_MAP:
-                layer.__class__ = MODULE_MAP[layer_type][0]
                 # Name mapping
                 for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]:
-                    weight = layer.__getattribute__(weight_type).copy()
-                    if weight.size == 0:
-                        continue
                     engine_weight_name = f"{layer.name} {weight_name}"
                     # Infer the corresponding weight name(s) in state_dict
                     sd_weight_name_list = (
@@ -538,17 +534,15 @@
                         elif "bias" in suffix:
                             sd_weight_name = f"{sd_weight_name}.bias"
                         else:
-                            # Save the constant weights for future fast refit
                             sd_weight_name = f"{sd_weight_name}.unknown"
-                            constant_mapping[engine_weight_name] = weight
                     elif layer_type == "SCALE":
                         # Batch norm needs all weights to calculate scale and shift
                         sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
                     else:
                         sd_weight_name = f"{sd_weight_name}.{torch_attr}"
 
-                    weight_name_map[engine_weight_name] = sd_weight_name
-                    np_map[engine_weight_name] = weight
+                    if engine_weight_name in np_map:
+                        weight_name_map[engine_weight_name] = sd_weight_name
 
         # Stage 2: Value mapping
         for engine_weight_name, sd_weight_name in weight_name_map.items():
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 3edcbad2dd..0808db2ed8 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -367,6 +367,7 @@ def create_constant(
         else:
             numpy_value = torch_value.numpy()
 
+        ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
         constant = ctx.net.add_constant(
             shape,
             numpy_value,
diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
index 6c10aafb7a..cf803c5ffa 100644
--- a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
+++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py
@@ -1,6 +1,8 @@
 import pytest
 
 flashinfer = pytest.importorskip("flashinfer")
+import unittest
+
 import torch
 import torch.nn as nn
 import torch_tensorrt
@@ -28,6 +30,7 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tenso
     )
 
 
+@unittest.skip("Not Available")
 class TestAutomaticPlugin(DispatchTestCase):
     @parameterized.expand(
         [
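
For context, a minimal usage sketch (not part of the patch) of the fast-refit path these changes touch. It assumes a CUDA-enabled build of Torch-TensorRT where torch_tensorrt.dynamo.compile accepts immutable_weights=False and torch_tensorrt.dynamo.refit_module_weights is exposed; the model and parameter names below are illustrative only.

# Sketch: compile with refittable weights, then refit the engine in place of a rebuild.
import torch
import torch_tensorrt


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


inputs = [torch.randn(2, 16).cuda()]

# Compile once with refittable weights so the engine records a weight-name mapping.
exp_program = torch.export.export(TinyModel().eval().cuda(), tuple(inputs))
trt_module = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    immutable_weights=False,  # assumed setting name for building a refittable engine
)

# Later, swap in a new set of weights without rebuilding the engine.
new_program = torch.export.export(TinyModel().eval().cuda(), tuple(inputs))
refitted_module = torch_tensorrt.dynamo.refit_module_weights(
    compiled_module=trt_module,
    new_weight_module=new_program,
    arg_inputs=inputs,
)
print(refitted_module(*inputs).shape)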