Skip to content

Commit c2ae18c

Browse files
committed
Fixed memory overhead and enabled Flux with Mutable Module
1 parent a0552c5 commit c2ae18c

File tree

6 files changed

+45
-84
lines changed

6 files changed

+45
-84
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def compile(
421421
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
422422
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
423423
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
424+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
424425
**kwargs: Any,
425426
) -> torch.fx.GraphModule:
426427
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -674,6 +675,7 @@ def compile(
674675
"enable_weight_streaming": enable_weight_streaming,
675676
"tiling_optimization_level": tiling_optimization_level,
676677
"l2_limit_for_tiling": l2_limit_for_tiling,
678+
"offload_module_to_cpu": offload_module_to_cpu,
677679
}
678680

679681
settings = CompilationSettings(**compilation_options)
@@ -685,7 +687,8 @@ def compile(
685687

686688
gm = exported_program.module()
687689
# TODO: Memory control prototyping. Under discussion
688-
exported_program.module().to("cpu")
690+
if offload_module_to_cpu:
691+
exported_program.module().to("cpu")
689692
logger.debug("Input graph: " + str(gm.graph))
690693

691694
# Apply lowering on the graph module

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
TILING_OPTIMIZATION_LEVEL = "none"
5050
L2_LIMIT_FOR_TILING = -1
5151
USE_DISTRIBUTED_MODE_TRACE = False
52+
OFFLOAD_MODULE_TO_CPU = True
5253

5354

5455
def default_device() -> Device:

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import copy
5+
import gc
56
import logging
67
from typing import Any, List, Optional, Sequence, Tuple
78

@@ -307,6 +308,10 @@ def refit_module_weights(
307308
get_decompositions(settings.enable_experimental_decompositions)
308309
)
309310
new_gm = new_weight_module.module()
311+
# TODO: Memory control prototyping. Under discussion
312+
if settings.offload_module_to_cpu:
313+
new_weight_module.module().to("cpu")
314+
310315
logger.debug("Input graph: " + str(new_gm.graph))
311316
# Apply lowering on the graph module
312317

@@ -462,12 +467,21 @@ def refit_module_weights(
462467
settings=settings,
463468
weight_name_map=None,
464469
)
470+
# TODO: Memory control prototyping. Under discussion
471+
if settings.offload_module_to_cpu:
472+
del new_submodule
473+
gc.collect()
474+
torch.cuda.empty_cache()
465475

466476
# clear EXCLUDE_WEIGHTS flag
467477
serialization_config = engine.create_serialization_config()
468478
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
469479
serialized_engine = engine.serialize_with_config(serialization_config)
470480

481+
del engine
482+
gc.collect()
483+
torch.cuda.empty_cache()
484+
471485
if isinstance(
472486
compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
473487
):

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
28+
OFFLOAD_MODULE_TO_CPU,
2829
OPTIMIZATION_LEVEL,
2930
PASS_THROUGH_BUILD_FAILURES,
3031
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -140,6 +141,7 @@ class CompilationSettings:
140141
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
141142
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
142143
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143145

144146

145147
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,9 +734,10 @@ def run(
734734
builder_config, self.compilation_settings.timing_cache_path
735735
)
736736
# TODO: Memory control prototyping. Under discussion
737-
del self.module
738-
gc.collect()
739-
torch.cuda.empty_cache()
737+
if self.compilation_settings.offload_module_to_cpu:
738+
del self.module
739+
gc.collect()
740+
torch.cuda.empty_cache()
740741
serialized_engine = self.builder.build_serialized_network(
741742
self.ctx.net, builder_config
742743
)

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 20 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22
import logging
33
from copy import deepcopy
44
from enum import Enum, auto
5-
from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union
5+
from typing import Any, Dict, Iterator, Optional, Union
66

77
import numpy as np
88
import torch
9-
from torch.fx.node import Target
109
from torch_tensorrt._Device import Device
11-
from torch_tensorrt._enums import EngineCapability, dtype
1210
from torch_tensorrt.dynamo import _defaults
1311
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
1412
from torch_tensorrt.dynamo._refit import refit_module_weights
15-
from torch_tensorrt.dynamo._settings import CompilationSettings
1613
from torch_tensorrt.dynamo.utils import (
1714
check_output_equal,
1815
to_torch_device,
@@ -63,35 +60,8 @@ def __init__(
6360
pytorch_model: torch.nn.Module,
6461
*,
6562
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
66-
disable_tf32: bool = _defaults.DISABLE_TF32,
67-
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
68-
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
69-
enabled_precisions: Set[
70-
Union[torch.dtype, dtype]
71-
] = _defaults.ENABLED_PRECISIONS,
72-
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
73-
immutable_weights: bool = False,
74-
debug: bool = _defaults.DEBUG,
75-
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
76-
workspace_size: int = _defaults.WORKSPACE_SIZE,
77-
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
78-
dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
79-
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
80-
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
81-
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
82-
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
83-
torch_executed_ops: Optional[Collection[Target]] = None,
84-
torch_executed_modules: Optional[List[str]] = None,
85-
pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
86-
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
87-
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
88-
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
8963
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
90-
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
91-
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
92-
dryrun: bool = _defaults.DRYRUN,
93-
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
94-
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
64+
immutable_weights: bool = False,
9565
**kwargs: Any,
9666
) -> None:
9767
"""
@@ -154,50 +124,15 @@ def __init__(
154124
self.exp_program: Any = None
155125
self.arg_inputs: tuple[Any, ...] = tuple()
156126
self.kwarg_inputs: dict[str, Any] = {}
157-
device = to_torch_tensorrt_device(device)
158-
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
127+
self.additional_settings = kwargs
128+
self.use_python_runtime = use_python_runtime
129+
self.trt_device = to_torch_tensorrt_device(device)
159130
assert (
160131
not immutable_weights
161-
), "`immutable_weights` has to be False for a MutableTorchTensorRTModule."
162-
compilation_options = {
163-
"enabled_precisions": (
164-
enabled_precisions
165-
if enabled_precisions
166-
else _defaults.ENABLED_PRECISIONS
167-
),
168-
"debug": debug,
169-
"device": device,
170-
"assume_dynamic_shape_support": assume_dynamic_shape_support,
171-
"workspace_size": workspace_size,
172-
"min_block_size": min_block_size,
173-
"torch_executed_ops": (
174-
torch_executed_ops if torch_executed_ops is not None else set()
175-
),
176-
"pass_through_build_failures": pass_through_build_failures,
177-
"max_aux_streams": max_aux_streams,
178-
"version_compatible": version_compatible,
179-
"optimization_level": optimization_level,
180-
"use_python_runtime": use_python_runtime,
181-
"truncate_double": truncate_double,
182-
"use_fast_partitioner": use_fast_partitioner,
183-
"num_avg_timing_iters": num_avg_timing_iters,
184-
"enable_experimental_decompositions": enable_experimental_decompositions,
185-
"require_full_compilation": require_full_compilation,
186-
"disable_tf32": disable_tf32,
187-
"sparse_weights": sparse_weights,
188-
"immutable_weights": immutable_weights,
189-
"engine_capability": engine_capability,
190-
"dla_sram_size": dla_sram_size,
191-
"dla_local_dram_size": dla_local_dram_size,
192-
"dla_global_dram_size": dla_global_dram_size,
193-
"dryrun": dryrun,
194-
"hardware_compatible": hardware_compatible,
195-
"timing_cache_path": timing_cache_path,
196-
}
132+
), "`immutable_weights has to be False for a MutableTorchTensorRTModule"
133+
197134
self.arg_dynamic_shapes: Optional[tuple[Any]] = None
198135
self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
199-
200-
self.settings = CompilationSettings(**compilation_options)
201136
self.run_info: Optional[tuple[Any, ...]] = None
202137
self.state_dict_metadata: dict[str, torch.Size] = {}
203138
self._store_state_dict_metadata()
@@ -293,7 +228,7 @@ def update_refit_condition(self) -> None:
293228
# to determine whether refit/recompilation is needed. If the output is the same, no further process needed.
294229
if self.run_info:
295230
args, kwargs, result = self.run_info
296-
self.original_model.to(to_torch_device(self.settings.device))
231+
self.original_model.to(to_torch_device(self.trt_device))
297232
new_result = self.original_model(*args, **kwargs)
298233
self.original_model.cpu()
299234
torch.cuda.empty_cache()
@@ -325,7 +260,7 @@ def refit_gm(self) -> None:
325260
MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module.
326261
If it fails to catch the changes, please call this function manually to update the TRT graph module.
327262
"""
328-
self.original_model.to(to_torch_device(self.settings.device))
263+
self.original_model.to(to_torch_device(self.trt_device))
329264
if self.exp_program is None:
330265
self.exp_program = torch.export.export(
331266
self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs
@@ -356,25 +291,30 @@ def compile(self) -> None:
356291
If it fails to catch the changes, please call this function manually to recompile the TRT graph module.
357292
"""
358293
# Export the module
359-
self.original_model.to(to_torch_device(self.settings.device))
360-
self.exp_program = torch.export.export(
294+
self.original_model.to(to_torch_device(self.trt_device))
295+
self.exp_program = torch.export._trace._export(
361296
self.original_model,
362297
self.arg_inputs,
363298
kwargs=self.kwarg_inputs,
364299
dynamic_shapes=self._get_total_dynamic_shapes(),
300+
strict=False,
301+
allow_complex_guards_as_runtime_asserts=True,
302+
# **self.additional_settings
365303
)
366304
self.gm = dynamo_compile(
367305
self.exp_program,
368306
arg_inputs=self.arg_inputs,
369307
kwarg_inputs=self.kwarg_inputs,
370-
**self.settings.__dict__,
308+
immutable_weights=False,
309+
use_python_runtime=self.use_python_runtime,
310+
**self.additional_settings,
371311
)
372312
self.original_model.cpu()
373313
torch.cuda.empty_cache()
374314

375315
def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
376316

377-
if not self.arg_inputs:
317+
if not self.arg_inputs and not self.kwarg_inputs:
378318
logger.info("First time compilation initiated. This may take some time.")
379319
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
380320
self._store_inputs(args, kwargs)
@@ -628,7 +568,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
628568
def save(module: Any, path: str) -> None:
629569
# Cast the object back to MutableTorchTensorRTModule to save
630570
assert (
631-
not module.settings.use_python_runtime
571+
not module.use_python_runtime
632572
), "Python runtime does not support serialization. Save failed."
633573
module.init_finished = False
634574
module.__class__ = MutableTorchTensorRTModule
@@ -658,7 +598,7 @@ def load(path: str) -> Any:
658598
module.pytorch_model = _make_refit_change_trigger(
659599
module.original_model, module.refit_state
660600
)
661-
module.original_model.to(to_torch_device(module.settings.device))
601+
module.original_model.to(to_torch_device(module.device))
662602
module.exp_program = torch.export.export(
663603
module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs
664604
)

0 commit comments

Comments
 (0)