[oss] Init gpt-oss bf16 support #22508
Conversation
Code Review

This pull request introduces support for the gpt-oss model with bf16 precision. The changes are extensive, touching fused MoE kernels, layer configurations, and adding a new model definition file. My review has identified a couple of issues. Firstly, in vllm/model_executor/layers/fused_moe/layer.py, the data type for MoE biases is hardcoded to torch.bfloat16, which should be parameterized to support other dtypes. Secondly, there's a critical bug in the weight loading logic for down_proj in vllm/model_executor/models/gpt_oss.py, where an incorrect permutation is applied. I've provided suggestions to fix these issues.
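For the first point, a rough sketch of what parameterizing the bias dtype might look like; the helper name, shapes, and the params_dtype argument are illustrative rather than taken from the PR:

```python
import torch

def make_moe_biases(num_experts: int,
                    intermediate_size_per_partition: int,
                    hidden_size: int,
                    params_dtype: torch.dtype):
    """Illustrative only: create MoE biases in the model's dtype instead of
    hardcoding torch.bfloat16."""
    w1_bias = torch.nn.Parameter(
        torch.zeros(num_experts, 2 * intermediate_size_per_partition,
                    dtype=params_dtype),  # instead of dtype=torch.bfloat16
        requires_grad=False)
    w2_bias = torch.nn.Parameter(
        torch.zeros(num_experts, hidden_size, dtype=params_dtype),
        requires_grad=False)
    return w1_bias, w2_bias
```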
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
The down_proj weight is being incorrectly permuted. The down_proj layer is a row-parallel layer, and its weight in vLLM has the shape (num_experts, hidden_size, intermediate_size_per_partition). After sharding, narrow_weight already has this shape. The permute(0, 2, 1) operation incorrectly changes it to (num_experts, intermediate_size_per_partition, hidden_size), which will cause shape mismatches and incorrect computations. The permutation should be removed.
Suggested change:
-    narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
-    param = params_dict[new_name]
-    param.copy_(narrow_weight)
+    param = params_dict[new_name]
+    param.copy_(narrow_weight)
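A quick shape check with small made-up dimensions shows why the extra permute breaks the copy, assuming the parameter shape described above:

```python
import torch

num_experts, hidden_size, intermediate_per_rank = 4, 8, 16

# Row-parallel down_proj parameter shape described in the comment above.
param = torch.empty(num_experts, hidden_size, intermediate_per_rank)
narrow_weight = torch.randn(num_experts, hidden_size, intermediate_per_rank)

param.copy_(narrow_weight)  # OK: shapes match exactly

permuted = narrow_weight.permute(0, 2, 1).contiguous()
print(permuted.shape)  # torch.Size([4, 16, 8])
# param.copy_(permuted) would raise a RuntimeError, since the last two
# dimensions no longer line up with the parameter.
```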
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can add the ready label to the PR.
rename_mapping = {
    "self_attn": "attn",
    "input_layernorm.weight": "attn.norm.weight",
    "post_attention_layernorm.weight": "mlp.norm.weight",
    "embed_tokens": "embedding",
}
I think we can use WeightsMapper here.
Completely agree, including the qkv mapping; I'm planning to complete it in a subsequent PR.
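A rough sketch of what a WeightsMapper-based rename could look like, mirroring the rename_mapping dict above; the import path, the orig_to_new_substr field, and the apply() usage follow the pattern seen in other vLLM models, but treat the exact API details as assumptions:

```python
# Sketch only: a substring-based mapper equivalent to rename_mapping above.
# If the import path or field names differ in the current codebase, adjust
# accordingly.
from vllm.model_executor.models.utils import WeightsMapper

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={
        "self_attn": "attn",
        "input_layernorm.weight": "attn.norm.weight",
        "post_attention_layernorm.weight": "mlp.norm.weight",
        "embed_tokens": "embedding",
    })

# Inside load_weights, the mapper would then replace maybe_rename, e.g.:
#   weights = hf_to_vllm_mapper.apply(weights)
```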
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
We should also check the usage of fused_moe; there are still tests and models (bert_with_rope.py, deepseek.py, and minicpm.py) using this function.
I have checked as thoroughly as possible, but there may still be omissions. Add ready and see if it introduces any new bugs.
def oss_act(gate_up):
    alpha = 1.702
    limit = 7.0
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
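The diff view truncates the function body; for context, a self-contained sketch of this style of interleaved, clamped SwiGLU activation (the clamp handling and the +1 on up follow the gpt-oss reference code, but treat the details as assumptions rather than the exact kernel used in this PR):

```python
import torch

def oss_act_sketch(gate_up: torch.Tensor,
                   alpha: float = 1.702,
                   limit: float = 7.0) -> torch.Tensor:
    # gate and up are interleaved along the last dimension
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)  # SiLU-like gate with alpha=1.702
    return (up + 1) * glu

x = torch.randn(2, 4, 16)        # last dim must be even
print(oss_act_sketch(x).shape)   # torch.Size([2, 4, 8])
```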
Does this implementation come from the gpt-oss repo?
@@ -1425,6 +1451,8 @@ def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
+    w1_bias: Optional[torch.Tensor],
I got this question from others. Can we fit bias into _zp?
We might be able to, but I'm not sure whether merging them is reasonable.
@bnellnm thoughts?
Yeah, I think using the existing _zp arguments would be best if possible. They are currently only used by the existing triton implementation for int4/int8 and are ignored otherwise. This may require flipping the sign, since it looks like the existing ZP are subtracted rather than added.
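A toy illustration of the sign-flip point; this is purely conceptual, since the real int4/int8 zero points are applied to the quantized weights rather than to the output:

```python
import torch

x = torch.randn(3, 4)
w = torch.randn(4, 5)
b = torch.randn(5)

# If a kernel computed `x @ w - zp_term`, an additive bias could only be
# emulated by flipping the sign of what is passed in:
zp_term = -b
assert torch.allclose(x @ w - zp_term, x @ w + b)
```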
TBH, it's too hard to maintain.
I think having two different sets of arguments which are basically the same thing will be confusing. As a user, I won't know whether to use _zp or _bias.
I understand that although they are similar, the semantics are different. Merging them might mislead users. At least in this PR, I don't want to implement this; personally, I don't like it.
@@ -160,7 +160,9 @@ def __init__(
     renormalize=True,
     quant_config=quant_config,
     prefix=f"{prefix}.experts",
-    apply_router_weight_on_input=False)
+    apply_router_weight_on_input=False,
+    has_bias=True,
Why do we need it here? Can it be detected inside this class, e.g. if layer.w1_bias is not None? It doesn't seem like a clean solution.
Edit: Actually, this is cleaner than not showing bias. Thanks for adding it.
If I haven't misunderstood your question: whether to have bias should be decided at the model level. If there's no bias, we don't need to register bias parameters, similar to a linear layer.
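A minimal sketch of that model-level decision: only register bias parameters when has_bias is set, mirroring how nn.Linear handles bias=False (the parameter names and shapes here are illustrative):

```python
import torch
from torch import nn

class TinyFusedMoE(nn.Module):
    """Illustrative layer: bias parameters exist only when has_bias=True."""

    def __init__(self, num_experts: int, hidden: int, inter: int,
                 has_bias: bool, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, 2 * inter, hidden, dtype=dtype))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden, inter, dtype=dtype))
        if has_bias:
            self.w1_bias = nn.Parameter(torch.zeros(num_experts, 2 * inter, dtype=dtype))
            self.w2_bias = nn.Parameter(torch.zeros(num_experts, hidden, dtype=dtype))
        else:
            # Keep the attributes around so callers can still check `is None`.
            self.register_parameter("w1_bias", None)
            self.register_parameter("w2_bias", None)

layer = TinyFusedMoE(num_experts=4, hidden=8, inter=16, has_bias=False)
print(layer.w1_bias is None)  # True
```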
Yes, I just realized that as well. Previously, I just hard-coded it into the mxfp4 class. Thanks for adding it.
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
I don't know why we can't use FusedMoE.weight_loader here?
It might be possible to implement this following llama4's approach; I will optimize it in later PRs.
Thanks. I didn't realize llama4 uses the same approach to load all experts at once. Yes, we can optimize this later.
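For reference, loading a checkpoint tensor that covers all experts at once and narrowing it to the current tensor-parallel rank could look roughly like this; the layout and variable names are illustrative, not llama4's or this PR's actual loader:

```python
import torch

tp_rank, tp_size = 1, 4
num_experts, hidden_size, intermediate_size = 4, 8, 32
shard = intermediate_size // tp_size  # columns per rank

# A checkpoint tensor covering all experts at once (down_proj-style layout).
loaded = torch.randn(num_experts, hidden_size, intermediate_size)

# Keep only this rank's slice of the intermediate dimension.
narrow_weight = loaded.narrow(2, tp_rank * shard, shard)
print(narrow_weight.shape)  # torch.Size([4, 8, 8])
```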
"embed_tokens": "embedding", | ||
} | ||
|
||
def maybe_rename(name: str) -> str: |
This entire function duplicates a lot of the mxfp4 one. Is there any chance we can combine them into one?
Yeah, we can optimize this using WeightsMapper. This model's weight loader has many areas that need optimization; consider implementing this later.
sweet
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 313283e to 2a4befe.
Signed-off-by: Jee Jee Li <[email protected]>
LGTM. We should set the default to the oai triton_kernels once we upgrade to 2.8.0. That kernel outperforms this kernel when the batch size is large. And please fix the load_weight later.
Essential Elements of an Effective PR Description Checklist
- Update supported_models.md and examples for a new model.

Purpose
WIP

Test Plan

Test Result
On the local A800 machine, the generated results are as follows:

(Optional) Documentation Update