[oss] Init gpt-oss bf16 support #22508


Merged

merged 4 commits into vllm-project:main from jeejeelee:oss-support-bf16 on Aug 10, 2025

Conversation

@jeejeelee (Collaborator) commented Aug 8, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

WIP:

  • Currently only supports eager mode
  • Support TP>1

Test Plan

import os
from vllm import LLM, SamplingParams

os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0 8.6"


def main():
    prompts = [
        "<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-08-08\n\nReasoning: medium\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.\nCalls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>Who are you..<|end|><|start|>assistant"  # noqa: E501
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0.8, top_p=0.95, max_tokens=1024
    )
    llm = LLM(
        model="unsloth/gpt-oss-20b-BF16",
        enforce_eager=True,
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt:    {prompt!r}")
        print(f"Output:    {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()

Test Result

On the local A800 machine, the generated results are as follows:

Generated Outputs:
------------------------------------------------------------
Prompt:    "<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-08-08\n\nReasoning: medium\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.\nCalls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>user<|message|>Who are you.<|end|><|start|>assistant"
Output:    'analysisThe user asks "Who are you." We need to respond in a helpful and concise manner. According to the style guidelines: "Do not mention policies or guidelines. Do not mention limitations. Use the style guidelines." We should identify ourselves as ChatGPT, an AI language model developed by OpenAI. Provide context about capabilities: text generation, answer questions, etc. Keep it concise, no fluff. Also no mention of policies. We can add that we can answer questions, provide information, etc. Also maybe mention that we don\'t have personal identity. So produce a short answer.assistantfinalI’m ChatGPT, a conversational AI developed by\u202fOpenAI. I generate responses to questions, offer explanations, help with creative writing, coding, and many other text‑based tasks. I don’t have personal experiences or feelings—just knowledge and language‑processing abilities.'

(Optional) Documentation Update

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for the gpt-oss model with bf16 precision. The changes are extensive, touching fused MoE kernels, layer configurations, and adding a new model definition file. My review has identified a couple of issues. Firstly, in vllm/model_executor/layers/fused_moe/layer.py, the data type for MoE biases is hardcoded to torch.bfloat16, which should be parameterized to support other dtypes. Secondly, there's a critical bug in the weight loading logic for down_proj in vllm/model_executor/models/gpt_oss.py, where an incorrect permutation is applied. I've provided suggestions to fix these issues.
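
For the first point, a minimal sketch of the kind of change the review suggests; the helper name, shapes, and parameter names below are illustrative, not the actual code in layer.py:

import torch
from torch import nn

# Illustrative only: let the bias dtype follow the layer's params_dtype
# (e.g. the model's dtype) instead of a hard-coded torch.bfloat16.
def create_moe_biases(num_experts: int, intermediate_size_per_partition: int,
                      hidden_size: int, params_dtype: torch.dtype):
    w13_bias = nn.Parameter(
        torch.zeros(num_experts, 2 * intermediate_size_per_partition,
                    dtype=params_dtype),
        requires_grad=False)
    w2_bias = nn.Parameter(
        torch.zeros(num_experts, hidden_size, dtype=params_dtype),
        requires_grad=False)
    return w13_bias, w2_bias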

Comment on lines +544 to +547
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]

param.copy_(narrow_weight)
Contributor

critical

The down_proj weight is being incorrectly permuted. The down_proj layer is a row-parallel layer, and its weight in vLLM has the shape (num_experts, hidden_size, intermediate_size_per_partition). After sharding, narrow_weight already has this shape. The permute(0, 2, 1) operation incorrectly changes it to (num_experts, intermediate_size_per_partition, hidden_size), which will cause shape mismatches and incorrect computations. The permutation should be removed.
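
To make the shape argument concrete, a toy check with made-up sizes (not the model's real dimensions):

import torch

# Illustrative sizes only: the sharded down_proj weight already has the
# vLLM layout (num_experts, hidden_size, intermediate_size_per_partition).
num_experts, hidden_size, intermediate_per_partition = 4, 8, 16
narrow_weight = torch.randn(num_experts, hidden_size, intermediate_per_partition)
param = torch.empty(num_experts, hidden_size, intermediate_per_partition)

param.copy_(narrow_weight)  # works: the shapes already match

# The extra permute changes (4, 8, 16) into (4, 16, 8), which no longer
# matches param, so the subsequent copy_ would fail.
permuted = narrow_weight.permute(0, 2, 1).contiguous()
assert permuted.shape != param.shape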

Suggested change

- narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
- param = params_dict[new_name]
- param.copy_(narrow_weight)
+ param = params_dict[new_name]
+ param.copy_(narrow_weight)

github-actions bot commented Aug 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Comment on lines +477 to +482
rename_mapping = {
    "self_attn": "attn",
    "input_layernorm.weight": "attn.norm.weight",
    "post_attention_layernorm.weight": "mlp.norm.weight",
    "embed_tokens": "embedding",
}
Member

I think we can use WeightsMapper here.
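
(For reference, a minimal sketch of what that could look like; the constructor field orig_to_new_substr is assumed here and should be checked against WeightsMapper in vllm/model_executor/models/utils.py.)

from vllm.model_executor.models.utils import WeightsMapper

# Hypothetical mapper covering the same renames as the dict in the excerpt.
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={
        "self_attn": "attn",
        "input_layernorm.weight": "attn.norm.weight",
        "post_attention_layernorm.weight": "mlp.norm.weight",
        "embed_tokens": "embedding",
    })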

Collaborator Author

Completely agree, including the qkv mapping; planning to complete it in a subsequent PR.

@mergify mergify bot added the performance Performance-related issues label Aug 8, 2025
Comment on lines 1688 to 1689
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
Member

We should also check the usage of fused_moe; there are still tests and models (bert_with_rope.py, deepseek.py, and minicpm.py) using this function.

Collaborator Author

I have checked as thoroughly as possible, but there may still be omissions.

@mergify mergify bot added the deepseek Related to DeepSeek models label Aug 8, 2025
@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 8, 2025
@jeejeelee (Collaborator Author)

Add ready and see if it introduces any new bugs

def oss_act(gate_up):
    alpha = 1.702
    limit = 7.0
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
Member

Does this implementation come from the gpt-oss repo?
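
(For context, the clamped SwiGLU in the gpt-oss reference implementation looks roughly like the sketch below; the clamping details and defaults are stated as assumptions to verify against that repo, not a quote of this PR's code.)

import torch

def oss_act_sketch(gate_up: torch.Tensor,
                   alpha: float = 1.702,
                   limit: float = 7.0) -> torch.Tensor:
    # Interleaved gate/up columns, as in the excerpt above.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    # Clamp both halves before the SiLU-style gating (assumed behavior).
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(alpha * gate)
    return (up + 1) * glu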


@@ -1425,6 +1451,8 @@ def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: Optional[torch.Tensor],
Member

I got this question from others. Can we fit bias into _zp?

Collaborator Author

Might be able to; I'm not sure whether merging them is reasonable.

Member

@bnellnm thoughts?

Contributor

Yeah, I think using the existing _zp arguments would be best if possible. They are currently only used by the existing triton implementation for int4/int8 and are ignored otherwise.

This may require flipping the sign since it looks like the existing ZP are subtracted rather than added.

Collaborator Author

TBH, it's too hard to maintain.

Contributor

I think having two different sets of arguments that are basically the same thing will be confusing. As a user, I won't know whether to use _zp or _bias.

Collaborator Author

I understand that although they are similar, the semantics are different, and merging them might mislead users. At least in this PR, I don't want to implement this; I personally don't like it.

@@ -160,7 +160,9 @@ def __init__(
     renormalize=True,
     quant_config=quant_config,
     prefix=f"{prefix}.experts",
-    apply_router_weight_on_input=False)
+    apply_router_weight_on_input=False,
+    has_bias=True,
@zyongye (Member) commented Aug 8, 2025

Why do we need it here? Can it be detected inside this class, e.g. if layer.w1_bias is not None? It doesn't seem like a clean solution.

Edit: Actually, this is cleaner than not showing the bias. Thanks for adding it.

Collaborator Author

If I haven't misunderstood your question, whether to have a bias should be decided at the model level. If there's no bias, we don't need to register bias parameters, similar to the linear layers.
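
(A minimal sketch of that model-level decision, with hypothetical names; the PR's actual registration code inside FusedMoE may differ.)

import torch
from torch import nn

def maybe_register_bias(layer: nn.Module, has_bias: bool,
                        bias_shape: tuple, params_dtype: torch.dtype) -> None:
    # Hypothetical helper: the model decides has_bias; when it is False,
    # no bias parameter is registered at all, like nn.Linear with bias=False.
    if not has_bias:
        return
    layer.register_parameter(
        "w13_bias",
        nn.Parameter(torch.zeros(*bias_shape, dtype=params_dtype),
                     requires_grad=False))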

Member

Yes, I just realized that as well. Previously, I just hard-coded it into the mxfp4 class. Thanks for adding it.

narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]

param.copy_(narrow_weight)
Member

I don't see why we can't use FusedMoE.weight_loader here?

Collaborator Author

It might be possible to implement this following llama4's approach; I will optimize it in later PRs.

Member

Thanks. I didn't realize llama4 uses the same approach to load all experts at once. Yes we can optimize this later.

"embed_tokens": "embedding",
}

def maybe_rename(name: str) -> str:
Member

This entire function duplicates a lot of the mxfp4 one. Is there any chance we can combine them into one?

Collaborator Author

Yeah, we can optimize this using WeightsMapper. This model's weight loader has many areas that need optimization; consider implementing this later.

Member

sweet

@jeejeelee jeejeelee requested a review from zyongye August 8, 2025 16:42
mergify bot commented Aug 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jeejeelee.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 9, 2025
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
@mergify mergify bot removed the needs-rebase label Aug 9, 2025
@zyongye (Member) left a comment

LGTM. We should set the default to oai triton_kernels once we upgrade to 2.8.0. That kernel outperforms this kernel when the batch size is large.

And please fix the load_weight later.

@vllm-bot vllm-bot merged commit 0c5254b into vllm-project:main Aug 10, 2025
37 of 46 checks passed
@jeejeelee jeejeelee deleted the oss-support-bf16 branch August 10, 2025 04:40
Labels: deepseek (Related to DeepSeek models), performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)

5 participants