36
36
TorchTensorRTModule ,
37
37
)
38
38
from torch_tensorrt .dynamo .utils import (
39
+ CPU_DEVICE ,
39
40
check_module_output ,
41
+ delete_module ,
40
42
get_model_device ,
41
43
get_torch_inputs ,
42
44
set_log_level ,
@@ -314,9 +316,6 @@ def refit_module_weights(
314
316
get_decompositions (settings .enable_experimental_decompositions )
315
317
)
316
318
new_gm = new_weight_module .module ()
317
- # TODO: Memory control prototyping. Under discussion
318
- if settings .offload_module_to_cpu :
319
- new_weight_module .module ().to ("cpu" )
320
319
321
320
logger .debug ("Input graph: " + str (new_gm .graph ))
322
321
# Apply lowering on the graph module
@@ -395,7 +394,7 @@ def refit_module_weights(
395
394
396
395
# Iterate over all components that can be accelerated
397
396
# Generate the corresponding TRT Module for those
398
-
397
+ new_weight_module . module (). to ( CPU_DEVICE )
399
398
for name , new_submodule in new_partitioned_module .named_children ():
400
399
# Refit each submodule
401
400
# Extract engine from the submodule
@@ -498,11 +497,7 @@ def refit_module_weights(
498
497
settings = settings ,
499
498
weight_name_map = None ,
500
499
)
501
- # TODO: Memory control prototyping. Under discussion
502
- if settings .offload_module_to_cpu :
503
- del new_submodule
504
- gc .collect ()
505
- torch .cuda .empty_cache ()
500
+ delete_module (new_submodule )
506
501
507
502
# clear EXCLUDE_WEIGHTS flag
508
503
serialization_config = engine .create_serialization_config ()
@@ -525,20 +520,18 @@ def refit_module_weights(
525
520
gc .collect ()
526
521
torch .cuda .empty_cache ()
527
522
528
- # TODO: Memory control prototyping. Under discussion
529
- if settings .offload_module_to_cpu :
530
- del new_partitioned_module
531
- gc .collect ()
532
- torch .cuda .empty_cache ()
523
+ delete_module (new_partitioned_module )
533
524
534
525
if verify_output and arg_inputs is not None :
526
+ new_gm .to (torch .cuda .current_device ())
535
527
if check_module_output (
536
528
new_module = new_gm ,
537
529
refitted_module = compiled_module ,
538
530
arg_inputs = torch_inputs ,
539
531
kwarg_inputs = torch_kwarg_inputs ,
540
532
):
541
533
logger .info ("Refitting Succeed!" )
534
+ new_gm .to (CPU_DEVICE )
542
535
else :
543
536
if weight_name_map :
544
537
logger .warning (
@@ -554,6 +547,7 @@ def refit_module_weights(
554
547
in_place = in_place ,
555
548
)
556
549
logger .error ("Refitting Failed! The outputs do not match." )
550
+ new_gm .to (CPU_DEVICE )
557
551
else :
558
552
logger .info ("Refitting Completed! Output verification skipped." )
559
553
0 commit comments