Skip to content

Commit 1bcb34c

Browse files
committed
Added CPU offloading
1 parent 557bec3 commit 1bcb34c

File tree

4 files changed

+20
-18
lines changed

4 files changed

+20
-18
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def compile(
421421
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
422422
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
423423
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
424+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
424425
**kwargs: Any,
425426
) -> torch.fx.GraphModule:
426427
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -550,15 +551,6 @@ def compile(
550551
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
551552
)
552553

553-
if (
554-
not immutable_weights
555-
and not refit_identical_engine_weights
556-
and enable_weight_streaming
557-
):
558-
raise ValueError(
559-
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
560-
)
561-
562554
if (
563555
"enable_cross_compile_for_windows" in kwargs.keys()
564556
and kwargs["enable_cross_compile_for_windows"]
@@ -674,6 +666,7 @@ def compile(
674666
"enable_weight_streaming": enable_weight_streaming,
675667
"tiling_optimization_level": tiling_optimization_level,
676668
"l2_limit_for_tiling": l2_limit_for_tiling,
669+
"offload_module_to_cpu": offload_module_to_cpu,
677670
}
678671

679672
settings = CompilationSettings(**compilation_options)
@@ -684,6 +677,9 @@ def compile(
684677
)
685678

686679
gm = exported_program.module()
680+
# TODO: Memory control prototyping. Under discussion
681+
if offload_module_to_cpu:
682+
exported_program.module().to("cpu")
687683
logger.debug("Input graph: " + str(gm.graph))
688684

689685
# Apply lowering on the graph module
@@ -820,6 +816,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
820816
trt_modules = {}
821817
# Iterate over all components that can be accelerated
822818
# Generate the corresponding TRT Module for those
819+
823820
for name, _ in partitioned_module.named_children():
824821
submodule = getattr(partitioned_module, name)
825822
# filter on the GraphModule

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
TILING_OPTIMIZATION_LEVEL = "none"
5050
L2_LIMIT_FOR_TILING = -1
5151
USE_DISTRIBUTED_MODE_TRACE = False
52+
OFFLOAD_MODULE_TO_CPU = False
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
28+
OFFLOAD_MODULE_TO_CPU,
2829
OPTIMIZATION_LEVEL,
2930
PASS_THROUGH_BUILD_FAILURES,
3031
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -488,14 +488,12 @@ def _save_weight_mapping(self) -> None:
488488
# Stage 1: Name mapping
489489
torch_device = to_torch_device(self.compilation_settings.device)
490490
gm_is_on_cuda = get_model_device(self.module).type == "cuda"
491-
if not gm_is_on_cuda:
492-
# If the model original position is on CPU, move it GPU
493-
sd = {
494-
k: v.reshape(-1).to(torch_device)
495-
for k, v in self.module.state_dict().items()
496-
}
497-
else:
498-
sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
491+
# If the model's original location is the CPU, move it to the GPU
492+
sd = {
493+
k: v.reshape(-1).to(torch_device)
494+
for k, v in self.module.state_dict().items()
495+
}
496+
499497
weight_name_map: dict[str, Any] = {}
500498
np_map = {}
501499
constant_mapping = {}
@@ -733,7 +731,11 @@ def run(
733731
self._create_timing_cache(
734732
builder_config, self.compilation_settings.timing_cache_path
735733
)
736-
734+
# TODO: Memory control prototyping. Under discussion
735+
if self.compilation_settings.offload_module_to_cpu:
736+
del self.module
737+
gc.collect()
738+
torch.cuda.empty_cache()
737739
serialized_engine = self.builder.build_serialized_network(
738740
self.ctx.net, builder_config
739741
)

0 commit comments

Comments
 (0)