26 | 26 | from torch.fx.passes.shape_prop import TensorMetadata
27 | 27 | from torch.utils._python_dispatch import _disable_current_modes
28 | 28 | from torch_tensorrt._enums import dtype
   | 29 | +from torch_tensorrt._features import needs_refit
29 | 30 | from torch_tensorrt._Input import Input
30 | 31 | from torch_tensorrt.dynamo import _defaults
31 | 32 | from torch_tensorrt.dynamo._engine_cache import BaseEngineCache

44 | 45 |     get_trt_tensor,
45 | 46 |     to_torch,
46 | 47 | )
47 |    | -from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
   | 48 | +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
48 | 49 | from torch_tensorrt.fx.observer import Observer
49 | 50 | from torch_tensorrt.logging import TRT_LOGGER
50 | 51 |
@@ -434,6 +435,7 @@ def check_weight_equal(
434 | 435 |         except Exception:
435 | 436 |             return torch.all(sd_weight == network_weight)
436 | 437 |
    | 438 | +    @needs_refit
437 | 439 |     def _save_weight_mapping(self) -> None:
438 | 440 |         """
439 | 441 |         Construct the weight name mapping from engine weight name to state_dict weight name.
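The `@needs_refit` decorator imported above gates refit-dependent methods; its actual behavior lives in `torch_tensorrt._features` and is not shown in this diff. As a rough illustration only (the `REFIT_AVAILABLE` flag and the error message below are assumptions, not the library's code), a feature-gating decorator of this kind typically looks like:

```python
# Illustrative sketch of a feature-gating decorator -- NOT the
# torch_tensorrt._features implementation.
from functools import wraps
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])

REFIT_AVAILABLE = True  # assumption: stand-in for a real capability probe


def needs_refit(fn: F) -> F:
    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not REFIT_AVAILABLE:
            # Fail fast with a clear message instead of erroring deep inside TRT calls.
            raise RuntimeError(
                f"{fn.__name__} requires refit support, which is unavailable in this build"
            )
        return fn(*args, **kwargs)

    return cast(F, wrapper)
```

The same decorator is applied to the engine-cache helpers further down in this diff.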
@@ -491,15 +493,10 @@ def _save_weight_mapping(self) -> None:
491 | 493 |         _LOGGER.info("Building weight name mapping...")
492 | 494 |         # Stage 1: Name mapping
493 | 495 |         torch_device = to_torch_device(self.compilation_settings.device)
494 |     | -        gm_is_on_cuda = get_model_device(self.module).type == "cuda"
495 |     | -        if not gm_is_on_cuda:
496 |     | -            # If the model original position is on CPU, move it GPU
497 |     | -            sd = {
498 |     | -                k: v.reshape(-1).to(torch_device)
499 |     | -                for k, v in self.module.state_dict().items()
500 |     | -            }
501 |     | -        else:
502 |     | -            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
    | 496 | +        sd = {
    | 497 | +            k: v.reshape(-1).to(torch_device)
    | 498 | +            for k, v in self.module.state_dict().items()
    | 499 | +        }
503 | 500 |         weight_name_map: dict[str, Any] = {}
504 | 501 |         np_map = {}
505 | 502 |         constant_mapping = {}
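The removed branch only moved the flattened `state_dict` to the GPU when the module started out on the CPU. The unconditional `.to(torch_device)` covers both cases, because `torch.Tensor.to` returns the tensor itself when it already has the requested device and dtype, so nothing is copied for models that are already on the target device. A minimal, standalone check of that behavior (plain PyTorch, independent of this file):

```python
import torch

# Tensor.to() is a no-op -- it returns the very same tensor object --
# when the tensor already has the requested device and dtype.
t = torch.randn(4)            # CPU tensor
assert t.to("cpu") is t       # no copy made

if torch.cuda.is_available():
    g = torch.randn(4, device="cuda")
    assert g.to("cuda") is g  # likewise for a tensor already on the GPU
```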
@@ -583,6 +580,7 @@ def _save_weight_mapping(self) -> None:
583 | 580 |         gc.collect()
584 | 581 |         torch.cuda.empty_cache()
585 | 582 |
    | 583 | +    @needs_refit
586 | 584 |     def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
587 | 585 |         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
588 | 586 |         # if not self.compilation_settings.strip_engine_weights:
@@ -610,6 +608,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
610 | 608 |             ),
611 | 609 |         )
612 | 610 |
    | 611 | +    @needs_refit
613 | 612 |     def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
614 | 613 |         # query the cached TRT engine
615 | 614 |         cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
@@ -720,7 +719,7 @@ def run(
720 | 719 |             if self.compilation_settings.reuse_cached_engines:
721 | 720 |                 interpreter_result = self._pull_cached_engine(hash_val)
722 | 721 |                 if interpreter_result is not None:  # hit the cache
723 |     | -                    return interpreter_result
    | 722 | +                    return interpreter_result  # type: ignore[no-any-return]
724 | 723 |
725 | 724 |         self._construct_trt_network_def()
726 | 725 |
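The added `# type: ignore[no-any-return]` suppresses mypy's `no-any-return` error, which is reported (when `warn_return_any` is enabled, as in `--strict`) whenever a function with a declared return type returns an expression mypy can only type as `Any`. A plausible reading, not confirmed by this diff alone, is that decorating `_pull_cached_engine` leaves its call site typed as `Any` from mypy's point of view. A minimal standalone reproduction with hypothetical names:

```python
# Hypothetical example reproducing mypy's [no-any-return] error; check with
# `mypy --strict` on this file. Names are illustrative, not from torch_tensorrt.
from typing import Any, Callable


def gate(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Loosely typed decorator: callers of the wrapped function now see Any.
    return fn


@gate
def lookup() -> int:
    return 1


def run() -> int:
    return lookup()  # error: Returning Any from function declared to return "int"  [no-any-return]
```

A targeted `# type: ignore[no-any-return]` on that return line, as this diff adds for the cached-engine early return, silences exactly this error.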