gpt-oss model enablement #1754
Conversation
Need to rebase onto #1776.
Looks great in general. Left some comments. May need a rebase on recent and near-future development.
Summary of current status: there are some prerequisite PRs:
Once these PRs land, I will refactor:
Please address all the comments before landing. I would appreciate it if you added the reason why we cannot use AuxOutput as a comment in the code. Thanks!
n_kv_heads: int = 8
sliding_window_size: int = 128
attn_mask_type: str = "causal"
use_flex_attn: bool = True
I explicitly left the parameter here to stay compatible with https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L428, where we need to call get_attention_masks.
I also added a note here to prevent users from changing this flag to False.
I think I found a tricky numerical bug in TP. Maybe we can disable it for now.
- Up to `window_size - 1` previous tokens
Args:
    window_size: The maximum number of tokens to attend to (including current token).
        Must be >= 1. A window_size of 1 means attend only to self.
Need to raise a ValueError if the user didn't set window_size >= 1.
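A minimal sketch of the requested check (the standalone helper shown here is hypothetical; in the PR this would live wherever window_size is consumed):

```python
def validate_window_size(window_size: int) -> None:
    # Fail fast on an invalid config instead of producing a silently wrong mask.
    if window_size < 1:
        raise ValueError(
            f"window_size must be >= 1, got {window_size}; "
            "a window_size of 1 means attending only to the current token."
        )
```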
mlp1_weight = self.mlp1_weight.to_local()
mlp1_bias = self.mlp1_bias.to_local()
mlp2_weight = self.mlp2_weight.to_local()
mlp2_bias = self.mlp2_bias.to_local()
This might not be correct.
When they are DTensors, x * mlp2_weight + mlp2_bias will have placements Partial + Replicate, and sharding propagation can automatically convert Replicate -> Partial first and then perform the addition.
However, when we call to_local, the DTensor placement info is discarded, so instead of adding mlp2_bias, the net effect will be adding tp_degree * mlp2_bias.
I don't have a clean way to solve this. For forward correctness, we can do mlp2_bias / tp_degree to cancel the extra reduction effect, but then the backward will have an extra * tp_degree. Can we wrap mlp2_bias / tp_degree in torch.no_grad so the backward doesn't perform the * tp_degree?
You can also disable TP / ETP altogether for gpt-oss for now and leave a TODO.
cc @ezyang @fmassa on the difficulties of making TP correct in a local-tensor region when there is a bias involved.
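A rough sketch of the forward-only rescaling being asked about, assuming the division should be invisible to autograd. A plain torch.no_grad block would detach the bias from the graph entirely, so this uses a detached correction term instead; tp_degree and the helper name are assumptions for illustration, not the PR's code:

```python
import torch

def add_bias_in_local_region(
    x_local: torch.Tensor, mlp2_bias: torch.Tensor, tp_degree: int
) -> torch.Tensor:
    # x_local is the local (conceptually Partial) matmul output; it will be
    # all-reduced across tp_degree ranks outside this region, so adding the full
    # bias on every rank would contribute tp_degree * mlp2_bias after the reduction.
    # Forward value: mlp2_bias / tp_degree; gradient w.r.t. mlp2_bias: identity,
    # because the correction term is detached from the autograd graph.
    scaled_bias = mlp2_bias + (mlp2_bias / tp_degree - mlp2_bias).detach()
    return x_local + scaled_bias
```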
self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2)))
self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2)))
self.mlp2_weight = nn.Parameter(torch.empty((num_experts, hidden_dim, dim)))
self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim)))
This is different from the main moe.py, where we initialize the weight params with shape (num_experts, out_dim, in_dim) and transpose them before use. The point is hardware efficiency (mainly in the low-precision case). We also need to change the TP / ETP plans to adapt.
See #1517
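A minimal sketch of that (num_experts, out_dim, in_dim) storage convention with a transpose at use time (sizes and names here are illustrative assumptions, not the PR's code):

```python
import torch
import torch.nn as nn

num_experts, dim, hidden_dim, tokens = 4, 256, 1024, 8

# Weights stored as (num_experts, out_dim, in_dim).
mlp1_weight = nn.Parameter(torch.empty((num_experts, hidden_dim * 2, dim)))
mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim)))
nn.init.trunc_normal_(mlp1_weight)
nn.init.trunc_normal_(mlp2_weight)

x = torch.randn(num_experts, tokens, dim)         # (E, tokens, in_dim)
# Transpose to (E, in_dim, out_dim) right before the matmul.
h = torch.bmm(x, mlp1_weight.transpose(-2, -1))   # -> (E, tokens, hidden_dim * 2)
```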
# 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
# as the first argument, which will cause an error.
# `FlexAttentionWrapper._compiled_flex_attn` is correct.
# 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
yeah can you explain this?
This API will be removed in a future release
Thanks! I also noticed that return_lse is being deprecated. The reason we use it here is that we want to use a TP annotation to turn the lse tensor back into a DTensor with placement Shard(1) (inside the TP region it's a plain tensor): https://github.com/pytorch/torchtitan/pull/1754/files#diff-3448dcaf6e8b68f3b66a8e1dd298273de3702f93de406569426cd9e03fd7f97bR222. We cannot annotate an AuxOutput() object directly using the TP APIs, and because we want to keep the model code parallelism-free, we don't want to manually turn AuxOut.lse into a DTensor.
I think an alternative is to handle it in FlexAttentionWrapper; if that's a better way, I will create another PR to fix it.
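For context, a minimal sketch of the return_lse call pattern under discussion (shapes are illustrative assumptions; this is not the PR's wrapper code):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Dummy q/k/v with shape (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# With return_lse=True, flex_attention returns the attention output together with
# the log-sum-exp tensor of shape (batch, heads, seq_len). It is this plain lse
# tensor that the TP plan annotates back into a Shard(1) DTensor outside the model code.
out, lse = flex_attention(q, k, v, return_lse=True)
```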
LGTM, some minor final comments
- Up to `window_size - 1` previous tokens
Args:
    window_size: The maximum number of tokens to attend to (including current token).
        Must be >= 1. A window_size of 1 means attend only to self.
comment not addressed
self.use_grouped_mm = use_grouped_mm
self.swiglu_limit = swiglu_limit

self.mlp1_weight = nn.Parameter(torch.empty((num_experts, hidden_dim * 2, dim)))
Please add a comment to indicate which dim is the input dim and which is the output dim.
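A sketch of the kind of annotation being requested, assuming the (num_experts, out_dim, in_dim) convention from the earlier comment (the wording is a suggestion, not the PR's code):

```python
# Stored as (num_experts, out_dim, in_dim): out_dim = hidden_dim * 2, in_dim = dim.
# The weight is transposed at use time before the matmul.
self.mlp1_weight = nn.Parameter(torch.empty((num_experts, hidden_dim * 2, dim)))
```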
Continuing development on top of #1559. Thanks @KhoomeiK for the initial contribution!
Initialized from the same seed checkpoint, with seed=0 and deterministic=True.

GPT-oss
[loss curve screenshot]
Run 1: dp_shard = 2
Run 2: dp_shard = 2, TP degree = 2 (NGPU=4)
[loss curve screenshot]
Run 3: dp_shard = 2, TP degree = 2, EP degree = 2 (NGPU=4)
[loss curve screenshot]
Run 4: dp_shard = 2, TP degree = 2, EP degree = 2, ETP degree = 2 (NGPU=4)
[loss curve screenshot]
Run 5: dp_shard = 2, EP degree = 2 (NGPU=2)
[loss curve screenshot]