Skip to content

Commit e85844a

Browse files
authored
fix: torch-TRT runtime cache attribute + standard-TRT fast refit regression (#4225)
Signed-off-by: tejaswinp <tejaswinp@nvidia.com>
1 parent 0da10b2 commit e85844a

2 files changed

Lines changed: 31 additions & 13 deletions

File tree

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.export import ExportedProgram
1313
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
1414
from torch_tensorrt._enums import dtype
15-
from torch_tensorrt._features import needs_refit
15+
from torch_tensorrt._features import ENABLED_FEATURES, needs_refit
1616
from torch_tensorrt._Input import Input
1717
from torch_tensorrt.dynamo import partitioning
1818
from torch_tensorrt.dynamo._exporter import inline_torch_modules
@@ -41,7 +41,6 @@
4141
TorchTensorRTModule,
4242
)
4343
from torch_tensorrt.dynamo.utils import (
44-
CPU_DEVICE,
4544
check_module_output,
4645
check_output_equal,
4746
get_model_device,
@@ -199,18 +198,24 @@ def _refit_single_trt_engine_with_gm(
199198
weight_dtype, weight.data_ptr(), torch.numel(weight)
200199
)
201200
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
202-
# Check completeness via two methods:
203-
# 1. get_missing_weights(): reports weights in connected engines
204-
# that were not set.
205-
# 2. Compare weights actually set vs all engine weights: catches
206-
# weights in independent engines that get_missing_weights() may not report.
201+
# get_missing_weights(): reports weights in connected engines
202+
# that were not set.
207203
missing_weights = refitter.get_missing_weights()
208-
unset_weights = {w for w in weight_list if w not in mapping}
209-
assert len(missing_weights) == 0 and len(unset_weights) == 0, (
210-
f"Fast refitting failed due to incomplete mapping"
211-
f" ({len(missing_weights)} missing,"
212-
f" {len(unset_weights)} unset)"
204+
assert len(missing_weights) == 0, (
205+
f"Fast refit failed: refitter.get_missing_weights() reports "
206+
f"{len(missing_weights)} of {len(weight_list)} engine weight(s) "
207+
f"were never set."
213208
)
209+
if ENABLED_FEATURES.tensorrt_rtx:
210+
# Compare weights actually set vs all engine weights: catches
211+
# weights in independent engines that get_missing_weights() may not report.
212+
unset_weights = {w for w in weight_list if w not in mapping}
213+
assert len(unset_weights) == 0, (
214+
f"Fast refit failed on TensorRT-RTX: {len(unset_weights)} of "
215+
f"{len(weight_list)} engine weight(s) had no entry in "
216+
f"weight_name_map. "
217+
f"Unset (showing up to 5): {sorted(unset_weights)[:5]}"
218+
)
214219

215220
else:
216221
mapping = construct_refit_mapping(new_gm, input_list, settings)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
multi_gpu_device_check,
2525
)
2626

27-
import tensorrt as trt
27+
# must import after torch_tensorrt to resolve tensorrt_rtx alias
28+
import tensorrt as trt # isort: skip
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -529,6 +530,12 @@ def _load_from_state_dict(
529530
self.output_names = state_dict[prefix + "output_names"]
530531
self.target_platform = state_dict[prefix + "platform"]
531532

533+
# Same rationale as __setstate__: ensure these exist before
534+
# setup_engine() so __del__ -> _save_runtime_cache() is safe even
535+
# if a future caller invokes this without __init__ having run.
536+
self.runtime_config = None
537+
self.runtime_cache = None
538+
532539
# Run multi-gpu device check to validate engine instantiation
533540
multi_gpu_device_check()
534541
self.setup_engine()
@@ -547,6 +554,12 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
547554
self.__dict__.update(state)
548555
# reset after unpickling, apbose: is this required though?
549556
self._nccl_comm = None
557+
# __getstate__ pops these; re-initialize before setup_engine() so
558+
# __del__ -> _save_runtime_cache() can always read them, including
559+
# on standard (non-RTX) TRT where setup_engine() does not call
560+
# _setup_runtime_config().
561+
self.runtime_config = None
562+
self.runtime_cache = None
550563
self.setup_engine()
551564

552565
def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:

0 commit comments

Comments
 (0)