
Commit 4242743

Changed CPU offload to default
1 parent f85820c commit 4242743

4 files changed: 8 additions and 16 deletions

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 5 deletions
@@ -421,7 +421,6 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
-    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -666,7 +665,6 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
-        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
@@ -677,9 +675,11 @@ def compile(
     )

     gm = exported_program.module()
-    # TODO: Memory control prototyping. Under discussion
-    if offload_module_to_cpu:
-        exported_program.module().to("cpu")
+    # Move the weights in the state_dict to CPU
+    exported_program.module().to("cpu")
+    logger.info(
+        "The model is moved to CPU during compilation. If you want to keep the model on GPU, call module.to('cuda') on the model after compilation."
+    )
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
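With this change, compile() always offloads the exported module's weights to CPU while the TensorRT engine is built, instead of gating the offload behind offload_module_to_cpu. A minimal caller-side sketch of the resulting workflow (the toy model, shapes, and the torch_tensorrt.dynamo.compile entry point are illustrative assumptions, not part of this commit):

import torch
import torch_tensorrt

# Toy model and example input, for illustration only.
model = torch.nn.Linear(16, 4).eval().cuda()
inputs = [torch.randn(2, 16).cuda()]

exported_program = torch.export.export(model, tuple(inputs))

# After this commit, compile() moves the exported module's weights to CPU
# unconditionally during engine building and logs a notice about it.
trt_module = torch_tensorrt.dynamo.compile(exported_program, inputs=inputs)

# Per the new log message: if the model should stay on GPU after
# compilation, move it back explicitly.
model.to("cuda")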

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
-OFFLOAD_MODULE_TO_CPU = False


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 2 deletions
@@ -25,7 +25,6 @@
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
-    OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -141,7 +140,6 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
-    offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 3 additions & 8 deletions
@@ -44,7 +44,7 @@
     get_trt_tensor,
     to_torch,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, delete_module, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER

@@ -491,8 +491,6 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        gm_is_on_cuda = get_model_device(self.module).type == "cuda"
-        # If the model original position is on CPU, move it GPU
         sd = {
             k: v.reshape(-1).to(torch_device)
             for k, v in self.module.state_dict().items()
@@ -735,11 +733,8 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-        # TODO: Memory control prototyping. Under discussion
-        if self.compilation_settings.offload_module_to_cpu:
-            del self.module
-            gc.collect()
-            torch.cuda.empty_cache()
+
+        delete_module(self.module)
         serialized_engine = self.builder.build_serialized_network(
             self.ctx.net, builder_config
         )
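The manual cleanup (del plus gc.collect() and torch.cuda.empty_cache()) is now delegated to delete_module from torch_tensorrt.dynamo.utils, whose body is not shown in this diff. A hypothetical standalone equivalent, inferred only from the inline code it replaces and not the library's actual implementation:

import gc

import torch


def delete_module(module: torch.nn.Module) -> None:
    # Hypothetical sketch mirroring the removed inline cleanup: drop the
    # reference to the module, force a GC pass, and release cached CUDA
    # blocks so engine building has more free device memory.
    # Note: `del` only drops the local name here; memory is reclaimed once
    # no other live references to the module remain in the caller.
    del module
    gc.collect()
    torch.cuda.empty_cache()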
