Added CPU offloading #3452


Status: Open. Wants to merge 5 commits into main.
26 changes: 17 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -37,6 +37,7 @@
pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
CPU_DEVICE,
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
@@ -421,6 +422,7 @@ def compile(
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -550,15 +552,6 @@ def compile(
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
)

if (
not immutable_weights
and not refit_identical_engine_weights
and enable_weight_streaming
):
raise ValueError(
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
)

if (
"enable_cross_compile_for_windows" in kwargs.keys()
and kwargs["enable_cross_compile_for_windows"]
@@ -674,6 +667,7 @@ def compile(
"enable_weight_streaming": enable_weight_streaming,
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}

settings = CompilationSettings(**compilation_options)
@@ -690,6 +684,18 @@
gm = post_lowering(gm, settings)
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
exported_program.module().to(CPU_DEVICE)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory / 2:
Review comment (Collaborator): total_memory // 2

logger.warning(
"The remaining GPU memory is not enough to compile the model. This may lead to an OOM error. Consider setting offload_module_to_cpu=True."
)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)
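
For context, a minimal usage sketch of the new flag (the model and inputs here are hypothetical, not part of this diff):

import torch
import torch_tensorrt

# Hypothetical toy model; any exported module follows the same path.
model = torch.nn.Linear(64, 64).eval().cuda()
inputs = (torch.randn(1, 64).cuda(),)
exp_program = torch.export.export(model, inputs)

# With offload_module_to_cpu=True, the PyTorch weights are moved to the CPU
# before the TensorRT engine is built, leaving GPU memory free for the build.
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs,
    offload_module_to_cpu=True,
)
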
@@ -820,6 +826,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
trt_modules = {}
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
Expand All @@ -833,6 +840,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
str(name),
str(submodule.graph),
)
submodule.to(to_torch_device(settings.device))
continue

if name not in submodule_node_dict:
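
For reference, the memory check added in this file relies on torch.cuda.mem_get_info, which returns (free_bytes, total_bytes) for the current device; a standalone sketch of the same heuristic:

import torch

def gpu_memory_is_tight(threshold: float = 0.5) -> bool:
    # torch.cuda.mem_get_info returns (free_bytes, total_bytes); the compiler
    # warns when less than half of the device memory remains free.
    free, total = torch.cuda.mem_get_info()
    return free < total * threshold
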
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -49,6 +49,7 @@
TILING_OPTIMIZATION_LEVEL = "none"
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
Review comment (Collaborator): Test the whole test suite with this set to True. I think we discussed making the default True here. Since that would be a breaking change, it should be mentioned in the release notes.



def default_device() -> Device:
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -109,7 +109,9 @@ def construct_refit_mapping(


def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
weight_name_map: dict[Any, Any],
state_dict: dict[Any, Any],
settings: CompilationSettings,
) -> dict[Any, Any]:
engine_weight_map = {}
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
@@ -120,7 +122,9 @@ def construct_refit_mapping_from_weight_name_map(
# If the weight is not in sd, we can leave it unchanged
continue
else:
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
to_torch_device(settings.device)
)

engine_weight_map[engine_weight_name] = (
engine_weight_map[engine_weight_name]
@@ -163,7 +167,7 @@ def _refit_single_trt_engine_with_gm(
"constant_mapping", {}
) # type: ignore
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict()
weight_name_map, new_gm.state_dict(), settings
)
constant_mapping_with_type = {}
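
A simplified sketch of the pattern these changes introduce (names are illustrative, not part of the diff): because the state_dict may now live on the CPU after offloading, each weight is moved to the compilation device as it is looked up.

import torch

def map_weights_to_device(
    weight_name_map: dict, state_dict: dict, device: torch.device
) -> dict:
    # Condensed version of construct_refit_mapping_from_weight_name_map:
    # weights offloaded to the CPU are moved back to the target device one
    # at a time during refit, rather than assuming GPU residency.
    engine_weight_map = {}
    for engine_name, (sd_name, _np_dtype) in weight_name_map.items():
        if sd_name in state_dict:
            engine_weight_map[engine_name] = state_dict[sd_name].to(device)
    return engine_weight_map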

2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -25,6 +25,7 @@
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OFFLOAD_MODULE_TO_CPU,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
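
For illustration, the new dataclass field can also be set directly on CompilationSettings (a minimal sketch; all other fields keep their defaults):

from torch_tensorrt.dynamo._settings import CompilationSettings

# offload_module_to_cpu defaults to OFFLOAD_MODULE_TO_CPU (False) and can be
# overridden per compilation.
settings = CompilationSettings(offload_module_to_cpu=True)
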
19 changes: 8 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -44,7 +44,7 @@
get_trt_tensor,
to_torch,
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, delete_module, to_torch_device
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

@@ -491,15 +491,11 @@ def _save_weight_mapping(self) -> None:
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
torch_device = to_torch_device(self.compilation_settings.device)
gm_is_on_cuda = get_model_device(self.module).type == "cuda"
if not gm_is_on_cuda:
# If the model original position is on CPU, move it GPU
sd = {
k: v.reshape(-1).to(torch_device)
for k, v in self.module.state_dict().items()
}
else:
sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
sd = {
k: v.reshape(-1).to(torch_device)
for k, v in self.module.state_dict().items()
}

weight_name_map: dict[str, Any] = {}
np_map = {}
constant_mapping = {}
@@ -737,7 +733,8 @@ def run(
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)

if self.compilation_settings.offload_module_to_cpu:
delete_module(self.module)
serialized_engine = self.builder.build_serialized_network(
self.ctx.net, builder_config
)
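
delete_module is imported from torch_tensorrt.dynamo.utils; its body is not part of this diff. A hedged sketch of what such a helper plausibly does before the memory-intensive build_serialized_network call:

import gc

import torch

def delete_module_sketch(module: torch.nn.Module) -> None:
    # Hypothetical stand-in for torch_tensorrt.dynamo.utils.delete_module:
    # drop parameter and buffer references on every submodule, then reclaim
    # the CUDA cache so the memory is available to the TensorRT builder.
    for submodule in module.modules():
        submodule._parameters.clear()
        submodule._buffers.clear()
    gc.collect()
    torch.cuda.empty_cache()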