Single-node specdec training broken due to FSDP incompatibility #1045

@benchislett

Description

Describe the bug

When following the sample EAGLE speculative-decoding training example on an 8xB200 node, FSDP is used by default but crashes: accelerate's `ParallelismConfig` validation rejects the configured FSDP (it requires FSDP version 2, which is apparently incompatible or unavailable here).

Broken as of #922

Steps/Code to reproduce bug

Follow the default Llama-3.2-1B EAGLE3 online training from the README:

./launch_train.sh --model meta-llama/Llama-3.2-1B-Instruct \
            --output_dir temp_logs/ \
            --data input_conversations/daring-anteater.jsonl  \
            --num_epochs 1 \
            --eagle_config eagle_config.json --ar_validate_steps -1 --train_bs 1

Crash log:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/root/eagle/ModelOptNew/examples/speculative_decoding/main.py", line 255, in <module>
[rank2]:     train()
[rank2]:   File "/root/eagle/ModelOptNew/examples/speculative_decoding/main.py", line 233, in train
[rank2]:     trainer = EagleTrainerWithAccLog(
[rank2]:   File "/root/eagle/ModelOptNew/.venv/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/root/eagle/ModelOptNew/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 458, in __init__
[rank2]:     self.create_accelerator_and_postprocess()
[rank2]:   File "/root/eagle/ModelOptNew/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 5512, in create_accelerator_and_postprocess
[rank2]:     self.accelerator = Accelerator(**args)
[rank2]:   File "/root/eagle/ModelOptNew/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 475, in __init__
[rank2]:     self.parallelism_config._validate_accelerator(self)
[rank2]:   File "/root/eagle/ModelOptNew/.venv/lib/python3.10/site-packages/accelerate/parallelism_config.py", line 364, in _validate_accelerator
[rank2]:     raise ValueError(
[rank2]: ValueError: ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{Device} or DistributedType.DEEPSPEED, but got FSDP.

Expected behavior

It should not crash; it should use pure DP (the old behaviour) rather than FSDP (the new behaviour). (Is there any practical difference when all the parameters fit on one GPU?)
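For context on the pure-DP question above: plain `DistributedDataParallel` replicates the full model on every rank and only all-reduces gradients, so for a 1B model that fits on a single B200 there is no sharding benefit to recover. A minimal single-process sketch (gloo backend on CPU, purely for illustration):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group so DDP can be constructed for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
ddp_model = DDP(model)  # pure DP: every rank holds a full replica

# With DDP each rank sees all parameters, matching the old (pre-#922)
# memory/compute profile when the model fits on one GPU.
total_params = sum(p.numel() for p in ddp_model.parameters())

dist.destroy_process_group()
```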

Who can help?

@yeyu-nvidia @ChenhanYu

System information

  • Container used (if applicable): ?
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 22.04.5 LTS
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): NVIDIA B200
  • GPU memory size: 179.1 GB
  • Number of GPUs: 8
  • Library versions (if applicable):
    • Python: 3.10.12
    • ModelOpt version or commit hash: 0.43.0.dev110+g1070d895d
    • CUDA: ?
    • PyTorch: 2.9.1+cu128
    • Transformers: 4.57.6
    • TensorRT-LLM: ?
    • ONNXRuntime: ?
    • TensorRT: ?
  • Any other details that may help: ?
