From 5235fcfaadd683b460cbb69055757f0dd6ac04e9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 18 Apr 2025 14:00:56 +0000 Subject: [PATCH 1/3] Enabled refit on Python 3.13 --- py/torch_tensorrt/_features.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 41 +++---------------- .../dynamo/conversion/_ConversionContext.py | 2 + .../dynamo/conversion/_TRTInterpreter.py | 14 ++----- .../dynamo/conversion/converter_utils.py | 1 + 5 files changed, 14 insertions(+), 46 deletions(-) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 29ab495fec..320a9818b5 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -37,7 +37,7 @@ _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path) _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") _FX_FE_AVAIL = True -_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13") +_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.14") ENABLED_FEATURES = FeatureSet( _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index a1eda40b2d..4142673212 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -62,18 +62,6 @@ def construct_refit_mapping( Returns: Mapping from weight name in TensorRT to actual weight value in np.ndarray """ - MODULE_MAP = { - "SCALE": (trt.IScaleLayer, [("scale", "SCALE"), ("shift", "SHIFT")]), - "CONVOLUTION": ( - trt.IConvolutionLayer, - [("kernel", "KERNEL"), ("bias", "BIAS")], - ), - "DECONVOLUTION": ( - trt.IDeconvolutionLayer, - [("kernel", "KERNEL"), ("bias", "BIAS")], - ), - "CONSTANT": (trt.IConstantLayer, [("weights", "CONSTANT")]), - } output_dtypes = infer_module_output_dtypes( module, @@ -81,7 +69,6 @@ def construct_refit_mapping( ) # Use Interpreter - weight_map = {} interpreter = TRTInterpreter( module, inputs, @@ -90,24 +77,8 @@ def construct_refit_mapping( compilation_settings=settings, ) interpreter._construct_trt_network_def() - net = interpreter.ctx.net - for i in range(net.num_layers): - layer = net[i] - layer_type: str = layer.type.name - if layer_type in MODULE_MAP: - # Cast the parent class to child class to access attributes - # For example: ILayer does not have ILayer.kernel/ILayer.bias - # So we cast it to IConvolutionLayer and access the attributes - layer.__class__ = MODULE_MAP[layer_type][0] - for weight_type, weight_name in MODULE_MAP[layer_type][1]: - weight = layer.__getattribute__(weight_type).copy() - weight_dtype = dtype.try_from(weight.dtype).to(trt.DataType) - weight_map[f"{layer.name} {weight_name}"] = ( - weight, - weight_dtype, - ) - return weight_map + return interpreter.ctx.mapping @needs_refit @@ -118,13 +89,12 @@ def construct_refit_mapping_from_weight_name_map( ) -> dict[Any, Any]: engine_weight_map = {} for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items(): - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) - if sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue else: + trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) + torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( to_torch_device(settings.device) ) @@ -208,8 +178,9 @@ def _refit_single_trt_engine_with_gm( if layer_name not in mapping: raise AssertionError(f"{layer_name} is not found in weight mapping") # Use Numpy to create weights - weight, datatype = mapping[layer_name] - trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size) + weight = mapping[layer_name] + trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType) + trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size) refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) refitted.add(layer_name) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 0dbdb2a8f4..94ebeb1dcd 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any, Dict from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.fx.types import TRTNetwork @@ -19,3 +20,4 @@ class ConversionContext: default_factory=CompilationSettings ) requires_output_allocator: bool = False + mapping: Dict[str, Any] = field(default_factory=dict) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index fde07bf1f5..908ceaec41 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -498,19 +498,15 @@ def _save_weight_mapping(self) -> None: for k, v in self.module.state_dict().items() } weight_name_map: dict[str, Any] = {} - np_map = {} - constant_mapping = {} + np_map = self.ctx.mapping + constant_mapping = {k: v for k, v in np_map.items() if v.size == 1} net = self.ctx.net for i in range(net.num_layers): layer = net[i] layer_type: str = layer.type.name if layer_type in MODULE_MAP: - layer.__class__ = MODULE_MAP[layer_type][0] # Name mapping for weight_type, weight_name, torch_attr in MODULE_MAP[layer_type][1]: - weight = layer.__getattribute__(weight_type).copy() - if weight.size == 0: - continue engine_weight_name = f"{layer.name} {weight_name}" # Infer the corresponding weight name(s) in state_dict sd_weight_name_list = ( @@ -538,17 +534,15 @@ def _save_weight_mapping(self) -> None: elif "bias" in suffix: sd_weight_name = f"{sd_weight_name}.bias" else: - # Save the constant weights for future fast refit sd_weight_name = f"{sd_weight_name}.unknown" - constant_mapping[engine_weight_name] = weight elif layer_type == "SCALE": # Batch norm needs all weights to calculate scale and shift sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr] else: sd_weight_name = f"{sd_weight_name}.{torch_attr}" - weight_name_map[engine_weight_name] = sd_weight_name - np_map[engine_weight_name] = weight + if engine_weight_name in np_map: + weight_name_map[engine_weight_name] = sd_weight_name # Stage 2: Value mapping for engine_weight_name, sd_weight_name in weight_name_map.items(): diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 3edcbad2dd..0808db2ed8 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -367,6 +367,7 @@ def create_constant( else: numpy_value = torch_value.numpy() + ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1) constant = ctx.net.add_constant( shape, numpy_value, From fd116f1e9a8c6fabc4f6d7d25016560a91862e8d Mon Sep 17 00:00:00 2001 From: cehongwang Date: Sat, 19 Apr 2025 12:10:18 +0000 Subject: [PATCH 2/3] Revised according to comments --- py/torch_tensorrt/_features.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 10 +++++----- .../dynamo/conversion/_ConversionContext.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py index 320a9818b5..bee0c3dbf0 100644 --- a/py/torch_tensorrt/_features.py +++ b/py/torch_tensorrt/_features.py @@ -37,7 +37,7 @@ _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path) _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev") _FX_FE_AVAIL = True -_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.14") +_REFIT_AVAIL = True ENABLED_FEATURES = FeatureSet( _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 4142673212..7be7e0f16c 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -93,8 +93,8 @@ def construct_refit_mapping_from_weight_name_map( # If weights is not in sd, we can leave it unchanged continue else: - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to( to_torch_device(settings.device) ) @@ -148,8 +148,8 @@ def _refit_single_trt_engine_with_gm( for constant_name, val in constant_mapping.items(): np_weight_type = val.dtype val_tensor = torch.from_numpy(val).cuda() - trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) - torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + trt_dtype = dtype._from(np_weight_type).to(trt.DataType) + torch_dtype = dtype._from(np_weight_type).to(torch.dtype) constant_mapping_with_type[constant_name] = ( val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), trt_dtype, @@ -179,7 +179,7 @@ def _refit_single_trt_engine_with_gm( raise AssertionError(f"{layer_name} is not found in weight mapping") # Use Numpy to create weights weight = mapping[layer_name] - trt_dtype = dtype.try_from(weight.dtype).to(trt.DataType) + trt_dtype = dtype._from(weight.dtype).to(trt.DataType) trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size) refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) refitted.add(layer_name) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 94ebeb1dcd..fa5eacf7c7 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from typing import Any, Dict +import numpy as np from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.fx.types import TRTNetwork @@ -20,4 +20,4 @@ class ConversionContext: default_factory=CompilationSettings ) requires_output_allocator: bool = False - mapping: Dict[str, Any] = field(default_factory=dict) + mapping: dict[str, np.array] = field(default_factory=dict) From 31ae9b57e40dda87943c958478fc21b010f5e8d9 Mon Sep 17 00:00:00 2001 From: cehongwang Date: Fri, 25 Apr 2025 03:52:06 +0000 Subject: [PATCH 3/3] Skip the flashinfer test --- tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py index 6c10aafb7a..cf803c5ffa 100644 --- a/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py +++ b/tests/py/dynamo/automatic_plugin/test_flashinfer_rmsnorm.py @@ -1,6 +1,8 @@ import pytest flashinfer = pytest.importorskip("flashinfer") +import unittest + import torch import torch.nn as nn import torch_tensorrt @@ -28,6 +30,7 @@ def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tenso ) +@unittest.skip("Not Available") class TestAutomaticPlugin(DispatchTestCase): @parameterized.expand( [