-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MoE][PoC] Expert Parallel: dp2ep #732
base: gh/tianyu-l/26/base
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -375,15 +375,23 @@ def __init__(self): | |
The default value is 'allgather'. | ||
""", | ||
) | ||
self.parser.add_argument( | ||
"--experimental.expert_parallel_degree", | ||
type=int, | ||
default=1, | ||
help=""" | ||
Expert parallelism degree. 1 means disabled. | ||
When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree. | ||
When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree, | ||
where k >= 1 and k | data_parallel_shard_degree. | ||
""", | ||
) | ||
self.parser.add_argument( | ||
"--experimental.expert_parallel_mode", | ||
type=str, | ||
default="none", | ||
choices=["none", "tp", "tp2ep"], | ||
help=""" | ||
Expert Parallel mode. | ||
'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension. | ||
""", | ||
choices=["none", "tp", "tp2ep", "dp2ep"], | ||
help="Expert Parallel mode", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dp2ep here would be using the DP mesh to shard non-shared experts on the num_experts dimension? If so, could you make it clear in the comments? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dp2ep would use "the entire |
||
) | ||
self.parser.add_argument( | ||
"--training.mixed_precision_param", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment isn't clear.
What does k | data_parallel_shard_degree mean?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It stands for
data_parallel_shard_degree % k == 0