
Conversation

@wwwjn (Contributor) commented Sep 24, 2025

Continuing development on top of #1559. Thanks @KhoomeiK for the initial contribution!

All runs are initialized from the same seed checkpoint, with seed = 0 and deterministic = True.
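
For reference, a minimal sketch of this seeding / determinism setup in plain PyTorch (torchtitan exposes equivalent options in its training config; this is illustrative, not the exact code path used for the runs below):

import torch

# Seed every RNG and force deterministic kernels so runs with the same
# parallelism layout produce matching results.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.use_deterministic_algorithms(True)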

GPT-oss
Run 1: dp_shard = 2
[screenshot]

Run 2: dp_shard = 2, TP degree = 2 (NGPU = 4)
[screenshot]

Run 3: dp_shard = 2, TP degree = 2, EP degree = 2 (NGPU = 4)
[screenshot]

Run 4: dp_shard = 2, TP degree = 2, EP degree = 2, ETP degree = 2 (NGPU = 4)
[screenshot]

Run 5: dp_shard = 2, EP degree = 2 (NGPU = 2)
[screenshot]
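
For context, the degrees above compose into the device mesh; e.g. Run 2's NGPU = 4 splits into dp_shard = 2 x TP = 2. A minimal sketch with PyTorch's DeviceMesh (illustrative only; torchtitan builds its meshes from the parallelism config rather than this hand-written snippet):

from torch.distributed.device_mesh import init_device_mesh

# Run 2 layout: 4 GPUs arranged as a 2 x 2 mesh (dp_shard x tp).
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_shard", "tp"))
dp_mesh = mesh["dp_shard"]  # FSDP sharding group
tp_mesh = mesh["tp"]        # tensor-parallel group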

@wwwjn force-pushed the gpt-oss branch 2 times, most recently from 48b2a11 to 07c0ff4 on September 30, 2025 04:34
@wwwjn (Contributor Author) commented Sep 30, 2025

Need to rebase onto #1776

@wwwjn wwwjn marked this pull request as ready for review September 30, 2025 23:01
@wwwjn changed the title from "[WIP] gpt-oss model enablement" to "gpt-oss model enablement" on Sep 30, 2025
@tianyu-l (Contributor) left a comment:

Looks great in general. Left some comments. May need a rebase onto recent and near-future development.

@wwwjn (Contributor Author) commented Oct 6, 2025

Summary of current status:

There are some prerequisite PRs:

  1. FlexAttn refactor: "Refactor attention and make attention mask an argument to the model" (#1776)
  2. EP refactor: "[EP] add initial support for NVSHMEM-based all-to-all" (#1569)
  3. Refactor freqs_cis as an input of model.forward(): "[RFC] Lift freqs_cis as an input of models" (#1797)

Once these PRs land, I will refactor:

  1. FlexAttention: add a sliding_window attention mask and make it orthogonal to the block_causal mask.
  2. The ExpertParallel() and ExpertTensorParallel() classes, to reuse as much code as possible and stay aligned with the main EP implementation.

@fegin (Contributor) left a comment:

Please address all the comments before landing. I would appreciate it if you added to the code the reason why we cannot use AuxOutput. Thanks!

n_kv_heads: int = 8
sliding_window_size: int = 128
attn_mask_type: str = "causal"
use_flex_attn: bool = True
Contributor Author (@wwwjn):

I explicitly leave the parameter here to stay compatible with https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L428, where we need to call get_attention_masks.

But I added a note here to prevent users from changing this flag to False.
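
For illustration, a guard along those lines could look like this (a sketch only; the class name and exact wording are hypothetical, not the code in this PR):

from dataclasses import dataclass

@dataclass
class GptOssModelArgsSketch:  # hypothetical name, mirroring the fields quoted above
    n_kv_heads: int = 8
    sliding_window_size: int = 128
    attn_mask_type: str = "causal"
    # gpt-oss builds its sliding-window masks on FlexAttention, and train.py
    # calls get_attention_masks unconditionally, so this flag must stay True.
    use_flex_attn: bool = True

    def __post_init__(self) -> None:
        if not self.use_flex_attn:
            raise ValueError(
                "gpt-oss requires use_flex_attn=True; sliding-window attention "
                "is implemented with FlexAttention."
            )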

@wwwjn wwwjn requested a review from tianyu-l October 19, 2025 03:33
@tianyu-l (Contributor) left a comment:

I think I found a tricky numerical bug in TP. Maybe we can disable it for now.

- Up to `window_size - 1` previous tokens
Args:
window_size: The maximum number of tokens to attend to (including current token).
Must be >= 1. A window_size of 1 means attend only to self.
Contributor:

Need to raise a ValueError if the user didn't set window_size >= 1.
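
A sketch of the requested validation in a FlexAttention sliding-window mask helper (the helper name and surrounding code are illustrative, not the PR's implementation):

from torch.nn.attention.flex_attention import create_block_mask

def sliding_window_causal_block_mask(window_size: int, seq_len: int, device: str = "cuda"):
    # Validate up front, as requested in the review.
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    def mask_mod(b, h, q_idx, kv_idx):
        # Attend to the current token and up to `window_size - 1` previous tokens.
        return (q_idx >= kv_idx) & (q_idx - kv_idx < window_size)

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)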

mlp1_weight = self.mlp1_weight.to_local()
mlp1_bias = self.mlp1_bias.to_local()
mlp2_weight = self.mlp2_weight.to_local()
mlp2_bias = self.mlp2_bias.to_local()
Contributor:

This might not be correct.

When they are DTensors, x * mlp2_weight + mlp2_bias will have placements Partial + Replicate, and sharding prop can automatically convert Replicate -> Partial first and then perform the addition.

However, when we call to_local, the DTensor placement info is discarded, so instead of adding mlp2_bias once, the net effect will be adding tp_degree * mlp2_bias.

I don't have a clean way to solve this. For forward correctness, we can do mlp2_bias / tp_degree to cancel the extra reduction effect, but then the backward will have an extra * tp_degree. Can we wrap mlp2_bias / tp_degree in torch.no_grad so the backward doesn't perform the * tp_degree?

You can also disable TP / ETP altogether for gpt-oss for now and leave a TODO.

cc @ezyang @fmassa on difficulties of making TP correct in a local tensor region, when there is bias involved.
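
One way to realize the "divide in forward, but not in backward" idea is a detach-based rescale; a sketch only, with hypothetical names, and without claiming it settles the TP gradient semantics discussed here:

import torch

def forward_only_divide(bias: torch.Tensor, tp_degree: int) -> torch.Tensor:
    # Forward value: bias - (bias - bias / tp_degree).detach() == bias / tp_degree.
    # Backward: the detached term carries no gradient, so d(out)/d(bias) == 1
    # rather than 1 / tp_degree.
    return bias - (bias - bias / tp_degree).detach()

# Hypothetical usage inside the local-tensor region:
#   out_local = out_local + forward_only_divide(mlp2_bias_local, tp_degree)
# Whether the resulting bias gradient is correct under the surrounding TP
# reductions is exactly the open question in this thread.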

Comment on lines 151 to 154
self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2)))
self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2)))
self.mlp2_weight = nn.Parameter(torch.empty((num_experts, hidden_dim, dim)))
self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim)))
Contributor:

This is different from the main moe.py, where we init the weight params with shape (num_experts, out_dim, in_dim) and transpose them before use. The point is hardware efficiency (mainly in the low-precision case). We also need to change the TP / ETP plans to adapt.

See #1517
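
A simplified sketch of the (num_experts, out_dim, in_dim) layout plus transpose-before-use pattern the comment refers to (illustrative; the clamped-swiglu activation and grouped-mm path of the real module are elided or simplified):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedExpertsSketch(nn.Module):
    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        # (num_experts, out_dim, in_dim), matching main moe.py; friendlier to
        # low-precision grouped GEMMs than (num_experts, in_dim, out_dim).
        self.mlp1_weight = nn.Parameter(torch.empty(num_experts, hidden_dim * 2, dim))
        self.mlp1_bias = nn.Parameter(torch.empty(num_experts, hidden_dim * 2))
        self.mlp2_weight = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.mlp2_bias = nn.Parameter(torch.empty(num_experts, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_experts, tokens_per_expert, dim); transpose weights just before use.
        h = torch.bmm(x, self.mlp1_weight.transpose(-2, -1)) + self.mlp1_bias.unsqueeze(1)
        gate, up = h.chunk(2, dim=-1)  # simplified activation; gpt-oss uses a clamped swiglu
        h = F.silu(gate) * up
        return torch.bmm(h, self.mlp2_weight.transpose(-2, -1)) + self.mlp2_bias.unsqueeze(1)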

# 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
# as the first argument, which will cause an error.
# `FlexAttentionWrapper._compiled_flex_attn` is correct.
# 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
Contributor:

Yeah, can you explain this?

Contributor:

This API will be removed in a future release

Contributor Author (@wwwjn):

Thanks! I also noticed that return_lse is being deprecated. The reason we use it here is that we want to use the TP annotation to change the lse tensor back into a DTensor with placement Shard(1) (inside the TP region it is a plain tensor): https://github.com/pytorch/torchtitan/pull/1754/files#diff-3448dcaf6e8b68f3b66a8e1dd298273de3702f93de406569426cd9e03fd7f97bR222. We cannot annotate an AuxOutput() object directly with the TP APIs, and because we want to keep the model code parallelism-free, we don't want to manually turn AuxOutput.lse into a DTensor.

I think an alternative is to handle it in FlexAttentionWrapper; if that is a better way, I will create another PR to fix it.
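
A rough sketch of that FlexAttentionWrapper alternative, re-wrapping lse with DTensor.from_local instead of relying on the TP module annotation (names and mesh plumbing are assumptions, not this PR's code):

import torch
from torch.distributed.tensor import DTensor, Shard
from torch.nn.attention.flex_attention import flex_attention

def flex_attention_with_dtensor_lse(q, k, v, block_mask=None, tp_mesh=None):
    out, lse = flex_attention(q, k, v, block_mask=block_mask, return_lse=True)
    if tp_mesh is not None:
        # Inside the TP region lse is a plain tensor sharded over heads;
        # re-wrap it so callers outside the region see a Shard(1) DTensor.
        lse = DTensor.from_local(lse, tp_mesh, placements=(Shard(1),))
    return out, lse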

@tianyu-l (Contributor) left a comment:

LGTM, some minor final comments

- Up to `window_size - 1` previous tokens
Args:
window_size: The maximum number of tokens to attend to (including current token).
Must be >= 1. A window_size of 1 means attend only to self.
Contributor:

comment not addressed

self.use_grouped_mm = use_grouped_mm
self.swiglu_limit = swiglu_limit

self.mlp1_weight = nn.Parameter(torch.empty((num_experts, hidden_dim * 2, dim)))
Contributor:

Please add a comment to indicate which dim is the input dim and which is the output dim.
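
For example, the requested annotation could read like this (a sketch; the final wording in the PR may differ):

# Weight layout: (num_experts, out_dim, in_dim); transposed right before use.
#   mlp1_weight: out_dim = hidden_dim * 2 (gate + up projections), in_dim = dim
#   mlp2_weight: out_dim = dim, in_dim = hidden_dim
self.mlp1_weight = nn.Parameter(torch.empty((num_experts, hidden_dim * 2, dim)))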
