Description
System Info
H100
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
python3 examples/auto_deploy/build_and_run_ad.py --model nvidia/Llama-3.1-70B-Instruct-FP8 --args.yaml-extra examples/auto_deploy/model_registry/configs/dashboard_default.yaml --args.yaml-extra examples/auto_deploy/model_registry/configs/world_size_8.yaml --benchmark.enabled
Expected behavior
The model should build and run successfully.
Actual behavior
Model fails with:
Error: AssertionError: fake mode from fake tensor input 0 doesn't match mode from fake tensor input 1
Additional notes
The core issue: Multiple FakeTensorMode instances are created (one during model building with init_empty_weights, potentially another during torch.compile), and PyTorch's fake mode detection fails when these are used together.
What's happening:
- Model building phase (each MPI process):
  • build_model("meta") uses the init_empty_weights context manager
  • init_empty_weights creates a FakeTensorMode instance per process
  • Model parameters/buffers become FakeTensor objects tied to that FakeTensorMode
- Compile phase (compile_model transform):
  • torch.compile(self.model, dynamic=True) wraps the model
  • When the compiled model runs, torch.inductor's pattern matcher processes operations
  • The pattern matcher calls torch._dynamo.utils.detect_fake_mode(args) on its inputs
- The mismatch:
• Model parameters are FakeTensor objects with FakeTensorMode A (from init_empty_weights)
• Input tensors from cm.named_args are real CUDA/CPU tensors
• During compilation, these inputs may be converted to FakeTensor objects
• If a new FakeTensorMode is created during compilation (or from a different process/context), you get FakeTensorMode B
• When both are used together, detect_fake_mode finds two different FakeTensorMode instances and raises: AssertionError: fake mode from fake tensor input 0 doesn't match mode from fake tensor input 1
Why it happens in distributed execution:
• Each MPI process creates its own FakeTensorMode when building models
• torch.compile/torch.inductor may create additional FakeTensorMode instances
• PyTorch requires all FakeTensor objects in an operation to share the same FakeTensorMode instance
• In distributed setups, these instances can conflict when tensors interact
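Given the constraint above, one mitigation pattern (a sketch only, not the project's actual fix) is to create a single FakeTensorMode per process and reuse it for both the build and compile phases. The helper name get_process_fake_mode below is hypothetical, not a TensorRT-LLM or PyTorch API:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical per-process singleton: every phase that needs a fake mode
# goes through this accessor instead of constructing its own instance.
_PROCESS_FAKE_MODE = None

def get_process_fake_mode() -> FakeTensorMode:
    global _PROCESS_FAKE_MODE
    if _PROCESS_FAKE_MODE is None:
        _PROCESS_FAKE_MODE = FakeTensorMode()
    return _PROCESS_FAKE_MODE

# Build phase: tensor factories under the mode produce FakeTensors
# tied to the shared instance.
mode = get_process_fake_mode()
with mode:
    weight = torch.empty(4, 4)

# Compile phase (or any later phase) sees the same instance, so
# detect_fake_mode never encounters two conflicting modes.
assert get_process_fake_mode() is mode
assert weight.fake_mode is mode
```

In an MPI setup each rank would still hold its own instance, which is fine: the requirement is only that tensors interacting within one process share one mode, not that modes match across processes.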
Before submitting a new issue...
- Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.