NCCL ops correction changes #3387
Two files are deleted by this PR.
```diff
@@ -10,11 +10,11 @@
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch.distributed.tensor import DTensor
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
     get_decompositions,
-    modify_reshape_complex_nodes,
     post_lowering,
     remove_detach,
     remove_sym_nodes,
```
```diff
@@ -52,25 +52,39 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings, engine_cache = parse_dynamo_kwargs(kwargs)
-    if settings.use_aot_joint_export:
-        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
-    logger.debug("Wrapping the backend with aot_autograd\n")
-    _pretraced_backend_autograd = functools.partial(
-        _pretraced_backend, settings=settings, engine_cache=engine_cache
-    )
-    settings_aot_autograd = {}
-    settings_aot_autograd["decompostions"] = get_decompositions(
-        settings.enable_experimental_decompositions
-    )
-    # This is added since detach lowering leads to alias nodes
-    # Error - View operation returned a tensor that is the same as the input base tensor
-    # torch nop_decompositions in torch/_decomp/decompositions.py
-    if aten.detach in settings_aot_autograd["decompositions"]:
-        del settings_aot_autograd["decompositions"][aten.detach]
-    return aot_autograd(
-        fw_compiler=_pretraced_backend_autograd,
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )(gm, sample_inputs)
+    if settings.use_distributed_mode_trace:
+        logger.debug(
+            "Wrapping the backend with aot_autograd for Distributed examples\n"
+        )
+        _pretraced_backend_autograd = functools.partial(
+            _pretraced_backend, settings=settings, engine_cache=engine_cache
+        )
+        settings_aot_autograd = {}
+        settings_aot_autograd["decompositions"] = get_decompositions(
+            settings.enable_experimental_decompositions
+        )
+        # This is added since detach lowering leads to alias nodes
+        # Error - View operation returned a tensor that is the same as the input base tensor
+        # torch nop_decompositions in torch/_decomp/decompositions.py
+        # transpose key deleted since not desirable to lower it to permute
+        to_delete = {
+            key
+            for key in settings_aot_autograd["decompositions"]
+            if "detach" in key._name
+        }
+        for key in to_delete:
+            del settings_aot_autograd["decompositions"][key]
+
+        return aot_autograd(
+            fw_compiler=_pretraced_backend_autograd,
+            decompositions=settings_aot_autograd["decompositions"],
+        )(gm, sample_inputs)
+
+    if any(isinstance(tensor, DTensor) for tensor in sample_inputs):
+        logger.warning(
+            "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported in aot_export_joint_simple"
+        )
+    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
 
 
 def _pretraced_backend(
```
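For reference, the decomposition filtering added above can be reproduced standalone. A minimal sketch, using torch's stock `core_aten_decompositions` as a stand-in for `get_decompositions` (an assumption; the `_name` attribute on the `OpOverload` keys is what the diff itself relies on):

```python
import torch
from torch._decomp import core_aten_decompositions

# Stand-in for get_decompositions(...); keys are torch._ops.OpOverload objects.
decompositions = core_aten_decompositions()

# Drop every detach overload, mirroring the diff: decomposing detach produces
# alias nodes, which later fail with "View operation returned a tensor that
# is the same as the input base tensor".
to_delete = {op for op in decompositions if "detach" in op._name}
for op in to_delete:
    del decompositions[op]

print(f"removed {len(to_delete)} detach decomposition(s)")
```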
```diff
@@ -110,18 +124,8 @@ def _pretraced_backend(
             # Remove detach nodes
             remove_detach(gm, settings)
 
-            complexInputIndices = []
-            for i, torch_input in enumerate(torch_inputs):
-                if torch_inputs[i].dtype == torch.complex64:
-                    complexInputIndices.append(i)
-                    torch_input_real = torch_inputs[i].real
-                    torch_input_imaginary = torch_inputs[i].imag
-                    torch_inputs[i] = torch.stack(
-                        (torch_input_real, torch_input_imaginary), dim=-1
-                    )
-
             # Invoke AOTAutograd to translate operators to aten
-            if settings.use_aot_joint_export:
+            if not settings.use_distributed_mode_trace:
                 gm = aot_export_joint_simple(
                     gm,
                     sample_inputs,
```

Review discussion on `if not settings.use_distributed_mode_trace:`:

> Do we need this check now, since we return from the aot_autograd block before?

> Yes, we would be needing this since while wrapping it in …
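To make the control flow above concrete, here is a hypothetical usage sketch of the new flag through `torch.compile`. The option name `use_distributed_mode_trace` comes from the diff; the backend string, model, and input are placeholder assumptions:

```python
import torch
import torch_tensorrt  # assumed to register the "torch_tensorrt" dynamo backend on import

model = torch.nn.Linear(16, 16).cuda()  # placeholder model
x = torch.randn(4, 16).cuda()

# use_distributed_mode_trace=True routes lowering through aot_autograd
# (the first branch above); False keeps the aot_export_joint_simple path.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"use_distributed_mode_trace": True},
)
out = compiled(x)
```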
```diff
@@ -137,12 +141,6 @@
 
             logger.debug("Lowered Input graph:\n " + str(gm.graph))
 
-            if complexInputIndices:
-                modify_reshape_complex_nodes(gm, complexInputIndices)
-                logger.debug(
-                    "Input graph after modifying complex nodes:\n " + str(gm.graph)
-                )
-
             torchtrt_inputs = prepare_inputs(
                 torch_inputs, disable_memory_format_check=True
             )
```
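For context on the two removed blocks, the old path packed each complex64 input into a real tensor before lowering. A minimal standalone sketch of that transformation:

```python
import torch

x = torch.randn(2, 3, dtype=torch.complex64)

# The removed code stacked real and imaginary parts on a new trailing dim,
# turning a complex64 tensor of shape (2, 3) into float32 of shape (2, 3, 2).
packed = torch.stack((x.real, x.imag), dim=-1)
assert packed.shape == (2, 3, 2) and packed.dtype == torch.float32
```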
Review discussion on the debug message:

> Consider changing the message to: "Using aot_autograd to trace the graph. Enable this only if the model includes distributed operations."

> Not sure about "only"; maybe "Using AOTAutograd tracer for model lowering", or something to that extent.
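A sketch of where the suggested wording would land; the final message was still under discussion in this thread, so the text below is one of the proposed options, not the merged one:

```python
import logging

logger = logging.getLogger(__name__)

# Proposed replacement for the "Wrapping the backend with aot_autograd for
# Distributed examples" message in the use_distributed_mode_trace branch.
logger.debug("Using AOTAutograd tracer for model lowering")
```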