
[Bug][AutoDeploy] Distributed execution fails: FakeTensorMode mismatch during torch.compile with world_size > 1 #10545

@tcherckez-nvidia

Description


System Info

H100

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

python3 examples/auto_deploy/build_and_run_ad.py --model nvidia/Llama-3.1-70B-Instruct-FP8 --args.yaml-extra examples/auto_deploy/model_registry/configs/dashboard_default.yaml --args.yaml-extra examples/auto_deploy/model_registry/configs/world_size_8.yaml --benchmark.enabled

Expected behavior

Model should build and run

Actual behavior

Model fails with:
Error: AssertionError: fake mode from fake tensor input 0 doesn't match mode from fake tensor input 1

Additional notes

The core issue: Multiple FakeTensorMode instances are created (one during model building with init_empty_weights, potentially another during torch.compile), and PyTorch's fake mode detection fails when tensors owned by different modes are used in the same operation.
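The assertion can be reproduced in isolation with two FakeTensorMode instances. This is a standalone sketch, not AutoDeploy code; FakeTensorMode and detect_fake_mode are private torch APIs (the issue already references torch._dynamo.utils.detect_fake_mode), so the import paths may shift between releases:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dynamo.utils import detect_fake_mode

# Two independent modes, standing in for the one created by init_empty_weights
# and the one created later during compilation.
mode_a = FakeTensorMode()
mode_b = FakeTensorMode()

fake_param = mode_a.from_tensor(torch.empty(4, 4))   # "model parameter" under mode A
fake_input = mode_b.from_tensor(torch.randn(2, 4))   # "runtime input" under mode B

# detect_fake_mode asserts that every FakeTensor it sees shares one mode, so
# mixing the two raises:
#   AssertionError: fake mode ... from fake tensor input 0 doesn't match
#   mode ... from fake tensor input 1
detect_fake_mode([fake_param, fake_input])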

What's happening:

  1. Model building phase (each MPI process):
    • build_model("meta") uses init_empty_weights context manager
    • init_empty_weights creates a FakeTensorMode instance per process
    • Model parameters/buffers become FakeTensor objects tied to that FakeTensorMode
  2. Compile phase (compile_model transform):
    • torch.compile(self.model, dynamic=True) wraps the model
    • When the compiled model runs, torch.inductor's pattern matcher processes operations
    • The pattern matcher calls torch._dynamo.utils.detect_fake_mode(args) on its inputs
  3. The mismatch:
    • Model parameters are FakeTensor objects with FakeTensorMode A (from init_empty_weights)
    • Input tensors from cm.named_args are real CUDA/CPU tensors
    • During compilation, these inputs may be converted to FakeTensor objects
    • If a new FakeTensorMode is created during compilation (or from a different process/context), you get FakeTensorMode B
    • When both are used together, detect_fake_mode finds two different FakeTensorMode instances and raises: AssertionError: fake mode from fake tensor input 0 doesn't match mode from fake tensor input 1 (the diagnostic sketch after this list shows how to confirm which mode owns each tensor)
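To confirm that this is the split, a small diagnostic can print the identity of every FakeTensorMode reachable from the model's parameters, buffers, and the (possibly already faked) inputs right before the pattern matcher runs. This is a hypothetical helper, not AutoDeploy code; module and args would correspond to self.model and the tensors taken from cm.named_args at the point where compile_model is applied:

import torch
from torch._subclasses.fake_tensor import FakeTensor

def report_fake_modes(module: torch.nn.Module, args) -> None:
    # Group every fake parameter/buffer/input by the id() of its owning
    # FakeTensorMode; more than one group means detect_fake_mode will assert.
    seen = {}

    def record(name, t):
        if isinstance(t, FakeTensor):
            seen.setdefault(id(t.fake_mode), []).append(name)

    for name, p in module.named_parameters():
        record(f"param:{name}", p)
    for name, b in module.named_buffers():
        record(f"buffer:{name}", b)
    for i, a in enumerate(args):
        if isinstance(a, torch.Tensor):
            record(f"input:{i}", a)

    for mode_id, owners in seen.items():
        print(f"FakeTensorMode {mode_id:#x}: {owners}")

On a world_size > 1 run, the expectation is that the parameters show up under one mode id and any already-faked inputs under a different one, which would pin the mismatch to the build-time FakeTensorMode surviving into the compile phase.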

Why it happens in distributed execution:
• Each MPI process creates its own FakeTensorMode when building models
• torch.compile/torch.inductor may create additional FakeTensorMode instances
• PyTorch requires all FakeTensor objects in an operation to share the same FakeTensorMode instance (the sketch after this list shows that invariant in isolation)
• In distributed setups, these instances end up conflicting as soon as a parameter faked at build time and an input faked at compile time meet in the same operation
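For completeness, the invariant PyTorch enforces is simply that all fakeification for a given trace goes through a single FakeTensorMode. The snippet below only sketches that invariant under the same assumptions as the earlier sketch; it is not a tested fix for AutoDeploy:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dynamo.utils import detect_fake_mode

shared_mode = FakeTensorMode()   # one mode reused for both "build" and "compile"

fake_param = shared_mode.from_tensor(torch.empty(4, 4))
fake_input = shared_mode.from_tensor(torch.randn(2, 4))

# Both FakeTensors report the same owning mode, so detect_fake_mode succeeds
# and returns that mode instead of asserting.
assert detect_fake_mode([fake_param, fake_input]) is shared_mode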

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.
