
Commit 98cbd76

Fixed the comments

1 parent 92ae47d

9 files changed (+32, -52 lines)

examples/apps/flux-demo.py (+2, -12)

@@ -3,8 +3,7 @@
 import gradio as gr
 import torch
 import torch_tensorrt
-from diffusers import FluxPipeline, StableDiffusionPipeline
-from torch.export._trace import _export
+from diffusers import FluxPipeline

 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
@@ -43,13 +42,7 @@
     "debug": False,
     "use_python_runtime": True,
     "immutable_weights": False,
-    # "cache_built_engines": True,
-    # "reuse_cached_engines": True,
-    # "timing_cache_path": "/home/engine_cache/flux.bin",
-    # "engine_cache_size": 40 * 1 << 30,
-    # "enable_weight_streaming": True,
-    # "weight_streaming_budget": 8 * 1 << 30
-    # "enable_cuda_graph": True,
+    "enable_cuda_graph": True,
 }

 trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
@@ -69,7 +62,6 @@ def generate_image(prompt, inference_step, batch_size=2):

 generate_image(["Test"], 2)
 torch.cuda.empty_cache()
-# torch_tensorrt.MutableTorchTensorRTModule.save(trt_gm, "weight_streaming_Flux.pkl")


 def model_change(model):
@@ -97,8 +89,6 @@ def load_lora(path):


 generate_image(["Test"], 2)
-# load_lora("")
-# generate_image(["A golden retriever holding a sign to code"], 2)

 # Create Gradio interface
 with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
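
Note: the demo now enables CUDA graphs outright rather than leaving the option commented out next to the removed cache/weight-streaming experiments. A minimal sketch of the pattern, with a toy `Linear` standing in for the Flux transformer backbone the real script compiles (the stand-in module and input shape are assumptions):

    import torch
    import torch_tensorrt

    # Toy stand-in; the demo compiles the Flux transformer backbone instead.
    backbone = torch.nn.Linear(64, 64).eval().cuda()

    settings = {
        "debug": False,
        "use_python_runtime": True,
        "immutable_weights": False,  # keep weights refittable (LoRA swaps below)
        "enable_cuda_graph": True,   # run the TRT engine through a CUDA graph
    }

    # Compilation happens lazily; the first call builds the TensorRT engine.
    trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
    out = trt_gm(torch.randn(1, 64, device="cuda"))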

examples/dynamo/refit_engine_example.py (+1)

@@ -101,6 +101,7 @@
 )

 # Check the output
+model2.to("cuda")
 expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
 for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
     assert torch.allclose(
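
Note: this one-liner is needed because refit now offloads the source module to the CPU (see the _refit.py changes below), so the example must move `model2` back to the GPU before computing reference outputs. A sketch of the comparison, reusing the example's names (`model2`, `exp_program2`, `new_trt_gm`, `inputs`); the tolerances are illustrative, not taken from the example:

    model2.to("cuda")  # refit left the PyTorch module on the CPU
    expected_outputs = exp_program2.module()(*inputs)
    refitted_outputs = new_trt_gm(*inputs)
    for expected, refitted in zip(expected_outputs, refitted_outputs):
        # Hypothetical tolerances; the real example may use different ones.
        assert torch.allclose(expected, refitted, rtol=1e-2, atol=1e-2)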

py/torch_tensorrt/dynamo/_compiler.py (-2)

@@ -422,7 +422,6 @@ def compile(
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
     tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
-    offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -667,7 +666,6 @@ def compile(
         "enable_weight_streaming": enable_weight_streaming,
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
-        "offload_module_to_cpu": offload_module_to_cpu,
     }

     settings = CompilationSettings(**compilation_options)
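
With the keyword removed from `compile`, existing call sites keep working as long as they never passed it. A minimal sketch of the surviving API surface, assuming the standard `torch_tensorrt.dynamo.compile(exported_program, inputs)` entry point (the toy model and inputs are placeholders):

    import torch
    import torch_tensorrt

    model = torch.nn.Linear(8, 8).eval().cuda()
    inputs = [torch.randn(1, 8, device="cuda")]

    exp_program = torch.export.export(model, tuple(inputs))
    # `offload_module_to_cpu` is no longer a named parameter; if passed, it
    # presumably falls through to **kwargs instead of being consumed.
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)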

py/torch_tensorrt/dynamo/_defaults.py (-1)

@@ -49,7 +49,6 @@
 TILING_OPTIMIZATION_LEVEL = "none"
 L2_LIMIT_FOR_TILING = -1
 USE_DISTRIBUTED_MODE_TRACE = False
-OFFLOAD_MODULE_TO_CPU = True


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_refit.py (+8, -14)

@@ -36,7 +36,9 @@
     TorchTensorRTModule,
 )
 from torch_tensorrt.dynamo.utils import (
+    CPU_DEVICE,
     check_module_output,
+    delete_module,
     get_model_device,
     get_torch_inputs,
     set_log_level,
@@ -314,9 +316,6 @@ def refit_module_weights(
         get_decompositions(settings.enable_experimental_decompositions)
     )
     new_gm = new_weight_module.module()
-    # TODO: Memory control prototyping. Under discussion
-    if settings.offload_module_to_cpu:
-        new_weight_module.module().to("cpu")

     logger.debug("Input graph: " + str(new_gm.graph))
     # Apply lowering on the graph module
@@ -395,7 +394,7 @@ def refit_module_weights(

     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
-
+    new_weight_module.module().to(CPU_DEVICE)
     for name, new_submodule in new_partitioned_module.named_children():
         # Refit each submodule
         # Extract engine from the submodule
@@ -498,11 +497,7 @@ def refit_module_weights(
                 settings=settings,
                 weight_name_map=None,
             )
-            # TODO: Memory control prototyping. Under discussion
-            if settings.offload_module_to_cpu:
-                del new_submodule
-                gc.collect()
-                torch.cuda.empty_cache()
+            delete_module(new_submodule)

         # clear EXCLUDE_WEIGHTS flag
         serialization_config = engine.create_serialization_config()
@@ -525,20 +520,18 @@ def refit_module_weights(
     gc.collect()
     torch.cuda.empty_cache()

-    # TODO: Memory control prototyping. Under discussion
-    if settings.offload_module_to_cpu:
-        del new_partitioned_module
-        gc.collect()
-        torch.cuda.empty_cache()
+    delete_module(new_partitioned_module)

     if verify_output and arg_inputs is not None:
+        new_gm.to(torch.cuda.current_device())
        if check_module_output(
            new_module=new_gm,
            refitted_module=compiled_module,
            arg_inputs=torch_inputs,
            kwarg_inputs=torch_kwarg_inputs,
        ):
            logger.info("Refitting Succeed!")
+            new_gm.to(CPU_DEVICE)
        else:
            if weight_name_map:
                logger.warning(
@@ -554,6 +547,7 @@ def refit_module_weights(
                     in_place=in_place,
                 )
             logger.error("Refitting Failed! The outputs do not match.")
+            new_gm.to(CPU_DEVICE)
     else:
         logger.info("Refitting Completed! Output verification skipped.")
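
Note: `delete_module` (imported from `torch_tensorrt.dynamo.utils` above) consolidates the repeated del / gc.collect / empty_cache blocks. Its body is not part of this diff; judging purely from the code it replaces, a plausible reconstruction looks like this speculative sketch:

    import gc

    import torch

    CPU_DEVICE = torch.device("cpu")  # assumed to match the imported constant

    def delete_module(module: torch.fx.GraphModule) -> None:
        """Speculative sketch: drop a module and reclaim GPU memory,
        mirroring the inline blocks this commit removes."""
        module.to(CPU_DEVICE)     # move parameters off the GPU first
        del module                # drop this (local) reference
        gc.collect()              # collect cycles that pin tensors
        torch.cuda.empty_cache()  # return freed blocks to the driver

Keep in mind that `del module` only removes the helper's local binding; the caller's own reference must also go away (as it does here when the loop rebinds `new_submodule` each iteration) before the memory is actually released.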

py/torch_tensorrt/dynamo/_settings.py (-2)

@@ -25,7 +25,6 @@
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
-    OFFLOAD_MODULE_TO_CPU,
     OPTIMIZATION_LEVEL,
     PASS_THROUGH_BUILD_FAILURES,
     REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -141,7 +140,6 @@ class CompilationSettings:
     tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
     l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
-    offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
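
Together with the _compiler.py and _defaults.py changes above, this removes `offload_module_to_cpu` as a user-facing setting; the refit path now offloads unconditionally. Since `CompilationSettings` is a dataclass, dropping the field also drops the constructor keyword, so stale call sites fail loudly. A quick check, assuming it remains a plain dataclass:

    from torch_tensorrt.dynamo._settings import CompilationSettings

    settings = CompilationSettings()  # fine: the remaining defaults still apply
    # CompilationSettings(offload_module_to_cpu=True)
    #   -> TypeError: __init__() got an unexpected keyword argument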

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py (+3, -2)

@@ -209,7 +209,7 @@ def forward(a, b, c=0, d=0):

         self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)

-    def _get_total_dynamic_shapes(self) -> dict[str, Any] | None:
+    def _get_total_dynamic_shapes(self) -> Union[dict[str, Any], None]:
         if not self.arg_dynamic_shapes and not self.kwarg_dynamic_shapes:
             return None
         total_dynamic_shape = {}
@@ -490,7 +490,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
     def to(self, *args: Any, **kwargs: Any) -> None:
         logger.warning(
             "Trying to move the original PyTorch model. This will cause CPU offloading failing and increase GPU memory usage."
-            + "If this is absolute necessary, please call module.pytorch_model.to(...)"
+            + "If this is absolute necessary, please call module.pytorch_model.to(...) \n"
+            + "The model is still on the original device."
         )

     @property
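
Note: the return-annotation change is a compatibility fix, not a behavior change. PEP 604's `X | Y` syntax is only evaluated successfully at runtime on Python 3.10+, while the `Union` spelling works on older interpreters too:

    from typing import Any, Union

    # On Python 3.9 this raises TypeError at definition time, because builtin
    # generics did not support `|` until 3.10 (PEP 604):
    #     def f() -> dict[str, Any] | None: ...

    # The Union spelling is equivalent and evaluates everywhere:
    def f() -> Union[dict[str, Any], None]:
        return None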

py/torch_tensorrt/runtime/_cudagraphs.py (+2, -2)

@@ -77,8 +77,8 @@ def __enter__(self) -> torch.nn.Module | torch.fx.GraphModule:
             self.old_module = self.compiled_module.gm
             self.compiled_module.gm = get_cuda_graph_module(self.compiled_module.gm)
             return self.compiled_module
-
-        return get_cuda_graph_module(self.compiled_module)
+        else:
+            return get_cuda_graph_module(self.compiled_module)

     def __exit__(self, *args: Any) -> None:
         # Set cudagraphs back to old mode
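
Note: this rewrite replaces an implicit fall-through return with an explicit else branch; behavior is unchanged. For context, a hedged sketch of how this context manager is typically used (`trt_gm` and `inputs` are placeholders for a compiled module and its inputs):

    import torch
    import torch_tensorrt

    # Inside the `with` block, TRT execution is captured into and replayed
    # from a CUDA graph; __exit__ restores the previous mode.
    with torch_tensorrt.runtime.enable_cudagraphs(trt_gm) as cudagraphs_module:
        out = cudagraphs_module(*inputs)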
