[MoE][PoC] Expert Parallel: dp2ep #732
base: gh/tianyu-l/26/base
Conversation
[ghstack-poisoned]
ghstack-source-id: 17160930f23950b91faca7b822cd3e7f9d075f7d Pull Request resolved: #732
ghstack-source-id: 2a70ed917b742c32118ef5ca02f161f833ce46bc Pull Request resolved: #732
Expert parallelism degree. 1 means disabled.
When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree,
where k >= 1 and k | data_parallel_shard_degree.
This comment isn't clear.
What does k | data_parallel_shard_degree mean?
It stands for data_parallel_shard_degree % k == 0
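A minimal sketch of the check this constraint implies (the helper name and the example degrees are hypothetical, not taken from the PR):

```python
# Hypothetical helper illustrating the constraint quoted above:
# expert_parallel_degree == k * context_parallel_degree, with k >= 1
# and data_parallel_shard_degree % k == 0 (i.e. k | data_parallel_shard_degree).
def validate_dp2ep_degrees(
    expert_parallel_degree: int,
    context_parallel_degree: int,
    data_parallel_shard_degree: int,
) -> int:
    k, rem = divmod(expert_parallel_degree, context_parallel_degree)
    if rem != 0 or k < 1:
        raise ValueError(
            "expert_parallel_degree must be k * context_parallel_degree with k >= 1"
        )
    if data_parallel_shard_degree % k != 0:
        raise ValueError("k must divide data_parallel_shard_degree")
    return k

# e.g. dp_shard=8, cp=2, ep=4  ->  k=2
assert validate_dp2ep_degrees(4, 2, 8) == 2
```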
'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
""",
choices=["none", "tp", "tp2ep", "dp2ep"],
help="Expert Parallel mode",
dp2ep here would be using the DP mesh to shard non-shared experts on the num_experts dimension? If so, could you make it clear in the comments?
dp2ep would use "the entire cp mesh (if existing) + part of the dp_shard mesh (namely dp_shard_2)" to shard non-shared experts.
Sorry for the confusion -- these PRs are not meant to land without changes. We'll definitely polish the descriptions later. Reading parallel_dims.py might be more informative for now.
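A rough numeric sketch of how the degrees could compose under dp2ep, based on the explanation above (the dp_shard_1/dp_shard_2 split and the example values are illustrative; parallel_dims.py in the PR is the authoritative source):

```python
# Illustrative numbers: dp_shard=8, cp=2, ep=4, so k = ep // cp = 2.
data_parallel_shard_degree = 8
context_parallel_degree = 2
expert_parallel_degree = 4

k = expert_parallel_degree // context_parallel_degree
assert expert_parallel_degree == k * context_parallel_degree and k >= 1
assert data_parallel_shard_degree % k == 0

# Conceptually, dp_shard splits into dp_shard_1 x dp_shard_2:
#   - dp_shard_2 (size k) joins the cp mesh to form the EP mesh that shards
#     non-shared experts on the num_experts dimension;
#   - dp_shard_1 is what remains of dp_shard for sharding the expert params.
dp_shard_2 = k                                  # 2
dp_shard_1 = data_parallel_shard_degree // k    # 4
ep_mesh_size = dp_shard_2 * context_parallel_degree
assert ep_mesh_size == expert_parallel_degree   # 4
```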
Stack from ghstack (oldest at bottom):
Temporary changes to unblock exploration
Turn foreach and clip_grad_norm_ off, as not all parameters are DTensors on the same meshes (e.g. (1) MoE non-shared experts and other params are on different FSDP meshes, and (2) moe.router.gate is a replicated torch.Tensor); see the sketch at the end of this section.
Also need to set full_graph=False, because there will be an additional FSDP inside a TransformerBlock at the non-shared experts level.

Things won't work
Not including
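As referenced above, a minimal sketch of the foreach / clip_grad_norm_ workaround, assuming a standard PyTorch training loop (the model, optimizer, and max_norm values are placeholders, not the PR's actual setup):

```python
import torch

# Placeholder model standing in for the torchtitan MoE Transformer.
model = torch.nn.Linear(16, 16)

# Disable the multi-tensor (foreach) optimizer path, since in this PoC
# parameters are DTensors placed on different meshes.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, foreach=False)

def training_step(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    # Gradient clipping likewise falls back to the per-tensor implementation.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=False)
    optimizer.step()
    optimizer.zero_grad()
    return loss
```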