Conversation

@jeejeelee
Collaborator

@jeejeelee jeejeelee commented Aug 8, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after comparison or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

Purpose

Test Plan

lm_eval

```bash
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm \
    --model_args "pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768" \
    --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
```

benchmark

```bash
# serve
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 vllm serve unsloth/gpt-oss-20b-BF16

# bench script
vllm bench serve \
    --backend vllm \
    --model unsloth/gpt-oss-20b-BF16 \
    --endpoint /v1/completions \
    --dataset-name random \
    --random-input 2048 \
    --random-output 1024 \
    --max-concurrency 10 \
    --num-prompt 100
```

Test Result (A800)

This PR

vllm (pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.3654|± |0.0133|
| | |strict-match | 5|exact_match|↑ |0.2593|± |0.0121|

```
============ Serving Benchmark Result ============
Successful requests:                     100
Maximum request concurrency:             10
Benchmark duration (s):                  146.79
Total input tokens:                      204483
Total generated tokens:                  65236
Request throughput (req/s):              0.68
Output token throughput (tok/s):         444.42
Total Token throughput (tok/s):          1837.46
---------------Time to First Token----------------
Mean TTFT (ms):                          389.82
Median TTFT (ms):                        178.96
P99 TTFT (ms):                           2009.92
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          21.35
Median TPOT (ms):                        21.27
P99 TPOT (ms):                           25.40
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.16
Median ITL (ms):                         19.71
P99 ITL (ms):                            137.77
==================================================
```

The main branch

vllm (pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.3457|± |0.0131|
| | |strict-match | 5|exact_match|↑ |0.2328|± |0.0116|

```
============ Serving Benchmark Result ============
Successful requests:                     100
Maximum request concurrency:             10
Benchmark duration (s):                  149.69
Total input tokens:                      204483
Total generated tokens:                  64990
Request throughput (req/s):              0.67
Output token throughput (tok/s):         434.16
Total Token throughput (tok/s):          1800.19
---------------Time to First Token----------------
Mean TTFT (ms):                          371.65
Median TTFT (ms):                        188.81
P99 TTFT (ms):                           2388.54
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.39
Median TPOT (ms):                        21.64
P99 TPOT (ms):                           41.43
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.58
Median ITL (ms):                         19.94
P99 ITL (ms):                            147.82
==================================================
```

@mgoin The GSM8K score seems too low. We also tested with the transformers backend and found that the scores were similar.

hf result

|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.3829|± |0.0134|
| | |strict-match | 5|exact_match|↑ |0.2570|± |0.0120|
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee marked this pull request as draft August 8, 2025 18:28
@github-actions

github-actions bot commented Aug 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new CUDA kernel for the swiglu_oai activation function, which is used in gpt_oss models. The changes span the CUDA kernel implementation, PyTorch bindings, a Python wrapper class, and corresponding tests. My review has identified a critical issue in the Python implementation where the data layout assumption is inconsistent with the CUDA kernel, and a high-severity bug in the test logic that inadvertently breaks existing tests for another activation function.

Comment on lines +253 to +258

```python
        gate, up = x[..., ::2], x[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        return gated_output
```
Contributor

critical

The forward_native implementation for SwiGLUOAI assumes an interleaved layout for gate and up tensors (x[..., ::2], x[..., 1::2]). However, the corresponding CUDA kernel clamp_swiglu_kernel_with_params expects a concatenated layout (first half is gate, second half is up), similar to other activation functions in this file like SiluAndMul. This inconsistency will lead to incorrect results when using the native PyTorch path and will cause the tests to fail. The implementation should be updated to match the CUDA kernel's expectation.

Suggested change

```diff
-gate, up = x[..., ::2], x[..., 1::2]
-gate = gate.clamp(min=None, max=self.limit)
-up = up.clamp(min=-self.limit, max=self.limit)
-glu = gate * torch.sigmoid(gate * self.alpha)
-gated_output = (up + 1) * glu
-return gated_output
+d = x.shape[-1] // 2
+gate, up = x[..., :d], x[..., d:]
+gate = gate.clamp(min=None, max=self.limit)
+up = up.clamp(min=-self.limit, max=self.limit)
+glu = gate * torch.sigmoid(gate * self.alpha)
+return (up + 1) * glu
```
Collaborator Author

Yes, you are right, but we should fix the kernel rather than this part.
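
To make the mismatch concrete, here is a minimal sketch (illustrative only, not code from this PR) showing that the two split conventions select different elements from the same input:

```python
import torch

x = torch.arange(8, dtype=torch.float32)  # stand-in for the activation input

# Interleaved convention (current forward_native): even indices = gate, odd = up
gate_i, up_i = x[..., ::2], x[..., 1::2]   # gate = [0, 2, 4, 6], up = [1, 3, 5, 7]

# Concatenated convention (CUDA kernel): first half = gate, second half = up
d = x.shape[-1] // 2
gate_c, up_c = x[..., :d], x[..., d:]      # gate = [0, 1, 2, 3], up = [4, 5, 6, 7]

# Same input, different gate/up tensors, hence different activation outputs.
assert not torch.equal(gate_i, gate_c)
```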

Comment on lines 85 to 90

```python
    if activation == "fatrelu":
        opcheck(fn, (out, x, threshold))
    if activation == "swiglu_oai":
        opcheck(fn, (out, x, layer.alpha, layer.limit))
    else:
        opcheck(fn, (out, x))
```
Contributor

high

The conditional logic for opcheck is incorrect. By adding a new if statement instead of an elif, you've broken the logic for the fatrelu activation. When activation == "fatrelu", opcheck is now called twice: once correctly, and a second time inside the final else block with the wrong number of arguments, which will cause the test to fail. This should be a single if/elif/else chain.

Suggested change

```diff
 if activation == "fatrelu":
     opcheck(fn, (out, x, threshold))
-if activation == "swiglu_oai":
+elif activation == "swiglu_oai":
     opcheck(fn, (out, x, layer.alpha, layer.limit))
 else:
     opcheck(fn, (out, x))
```
Collaborator Author

Good catch.

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee
Collaborator Author

@WoosukKwon @zyongye It looks like we can use the Triton kernel; is it still necessary to add a CUDA kernel?

@zyongye
Member

zyongye commented Aug 9, 2025

That's right. The Triton kernel should support this, but I guess it can only be used on SM80+?
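
For context, a compute-capability gate along these lines could decide when the Triton path is usable (a sketch with an assumed helper name, not code from this PR):

```python
import torch

def can_use_triton_kernel() -> bool:
    # Hypothetical guard: assume the Triton kernel requires Ampere (SM80)
    # or newer; older GPUs would fall back to the CUDA or native path.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8
```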

@mgoin mgoin self-assigned this Aug 10, 2025
@mergify mergify bot added the gpt-oss Related to GPT-OSS models label Aug 11, 2025


```python
@CustomOp.register("swiglu_oai")
class SwiGLUOAI(CustomOp):
```
Member

Seems like this should be a SwigluOAIAndMul. Don't forget to put it in the registry at the bottom.
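
A minimal self-contained sketch of the rename and registration (the CustomOp stub and registry shape below are illustrative stand-ins; the real classes live in vLLM's activation module):

```python
class CustomOp:
    # Stand-in for vLLM's CustomOp so this sketch runs on its own.
    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str):
        def wrap(op_cls):
            cls._registry[name] = op_cls
            return op_cls
        return wrap


@CustomOp.register("swiglu_oai")
class SwigluOAIAndMul(CustomOp):  # renamed from SwiGLUOAI per the review
    pass


assert CustomOp._registry["swiglu_oai"] is SwigluOAIAndMul
```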

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee marked this pull request as ready for review August 12, 2025 09:35
@jeejeelee jeejeelee requested a review from mgoin August 12, 2025 09:35
Member

@mgoin mgoin left a comment

LGTM, just a few nits

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee requested a review from yewentao256 as a code owner August 13, 2025 02:42
@jeejeelee jeejeelee requested a review from mgoin August 13, 2025 02:43
Member

@mgoin mgoin left a comment

LGTM, thanks!

@mgoin mgoin enabled auto-merge (squash) August 13, 2025 17:12
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 13, 2025
@zyongye
Member

zyongye commented Aug 13, 2025

Could you change the name here as well?

`or scoring_func != "softmax" or activation != "swiglu_oai"`

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee
Collaborator Author

> Could you change the name here as well?
>
> `or scoring_func != "softmax" or activation != "swiglu_oai"`

Done in 8471138

@simon-mo simon-mo disabled auto-merge August 15, 2025 00:06
@simon-mo simon-mo merged commit 81f4b96 into vllm-project:main Aug 15, 2025
68 of 73 checks passed
simon-mo added a commit that referenced this pull request Aug 15, 2025
@simon-mo
Collaborator

@jeejeelee please re-open; I ran into an issue after the merge, see #22948 (comment)

@jeejeelee jeejeelee deleted the support-oss-activation branch August 15, 2025 01:22
@jeejeelee
Collaborator Author

It's caused by the merge order; I'll try to resolve it.

yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Labels

gpt-oss: Related to GPT-OSS models
ready: ONLY add when PR is ready to merge/full CI is needed

4 participants