|
12 | 12 | from torch.export import ExportedProgram |
13 | 13 | from torch.fx.experimental.proxy_tensor import unset_fake_temporarily |
14 | 14 | from torch_tensorrt._enums import dtype |
15 | | -from torch_tensorrt._features import needs_refit |
| 15 | +from torch_tensorrt._features import ENABLED_FEATURES, needs_refit |
16 | 16 | from torch_tensorrt._Input import Input |
17 | 17 | from torch_tensorrt.dynamo import partitioning |
18 | 18 | from torch_tensorrt.dynamo._exporter import inline_torch_modules |
|
41 | 41 | TorchTensorRTModule, |
42 | 42 | ) |
43 | 43 | from torch_tensorrt.dynamo.utils import ( |
44 | | - CPU_DEVICE, |
45 | 44 | check_module_output, |
46 | 45 | check_output_equal, |
47 | 46 | get_model_device, |
@@ -199,18 +198,24 @@ def _refit_single_trt_engine_with_gm( |
199 | 198 | weight_dtype, weight.data_ptr(), torch.numel(weight) |
200 | 199 | ) |
201 | 200 | refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location) |
202 | | - # Check completeness via two methods: |
203 | | - # 1. get_missing_weights(): reports weights in connected engines |
204 | | - # that were not set. |
205 | | - # 2. Compare weights actually set vs all engine weights: catches |
206 | | - # weights in independent engines that get_missing_weights() may not report. |
| 201 | + # get_missing_weights(): reports weights in connected engines |
| 202 | + # that were not set. |
207 | 203 | missing_weights = refitter.get_missing_weights() |
208 | | - unset_weights = {w for w in weight_list if w not in mapping} |
209 | | - assert len(missing_weights) == 0 and len(unset_weights) == 0, ( |
210 | | - f"Fast refitting failed due to incomplete mapping" |
211 | | - f" ({len(missing_weights)} missing," |
212 | | - f" {len(unset_weights)} unset)" |
| 204 | + assert len(missing_weights) == 0, ( |
| 205 | + f"Fast refit failed: refitter.get_missing_weights() reports " |
| 206 | + f"{len(missing_weights)} of {len(weight_list)} engine weight(s) " |
| 207 | + f"were never set." |
213 | 208 | ) |
| 209 | + if ENABLED_FEATURES.tensorrt_rtx: |
| 210 | + # Compare weights actually set vs all engine weights: catches |
| 211 | + # weights in independent engines that get_missing_weights() may not report. |
| 212 | + unset_weights = {w for w in weight_list if w not in mapping} |
| 213 | + assert len(unset_weights) == 0, ( |
| 214 | + f"Fast refit failed on TensorRT-RTX: {len(unset_weights)} of " |
| 215 | + f"{len(weight_list)} engine weight(s) had no entry in " |
| 216 | + f"weight_name_map. " |
| 217 | + f"Unset (showing up to 5): {sorted(unset_weights)[:5]}" |
| 218 | + ) |
214 | 219 |
|
215 | 220 | else: |
216 | 221 | mapping = construct_refit_mapping(new_gm, input_list, settings) |
|
0 commit comments