
Conversation

@jeejeelee (Collaborator) commented Aug 15, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after comparison or e2e results.
  • (Optional) Any necessary documentation updates, such as updating supported_models.md and examples for a new model.

Purpose

Reopen of #22538

Test Plan

lm eval

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args "pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768" --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

benchmark

# serve
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 vllm serve unsloth/gpt-oss-20b-BF16

# bench script
vllm bench serve \
  --backend vllm \
  --model unsloth/gpt-oss-20b-BF16 \
  --endpoint /v1/completions \
  --dataset-name random \
  --random-input 2048 \
  --random-output 1024 \
  --max-concurrency 10 \
  --num-prompt 100

Test Result (A800)

This PR

vllm (pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3654|±  |0.0133|
|     |       |strict-match    |     5|exact_match|↑  |0.2593|±  |0.0121|

============ Serving Benchmark Result ============
Successful requests:                     100
Maximum request concurrency:             10
Benchmark duration (s):                  146.79
Total input tokens:                      204483
Total generated tokens:                  65236
Request throughput (req/s):              0.68
Output token throughput (tok/s):         444.42
Total Token throughput (tok/s):          1837.46
---------------Time to First Token----------------
Mean TTFT (ms):                          389.82
Median TTFT (ms):                        178.96
P99 TTFT (ms):                           2009.92
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          21.35
Median TPOT (ms):                        21.27
P99 TPOT (ms):                           25.40
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.16
Median ITL (ms):                         19.71
P99 ITL (ms):                            137.77
==================================================

The main branch

vllm (pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3457|±  |0.0131|
|     |       |strict-match    |     5|exact_match|↑  |0.2328|±  |0.0116|

============ Serving Benchmark Result ============
Successful requests:                     100
Maximum request concurrency:             10
Benchmark duration (s):                  149.69
Total input tokens:                      204483
Total generated tokens:                  64990
Request throughput (req/s):              0.67
Output token throughput (tok/s):         434.16
Total Token throughput (tok/s):          1800.19
---------------Time to First Token----------------
Mean TTFT (ms):                          371.65
Median TTFT (ms):                        188.81
P99 TTFT (ms):                           2388.54
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.39
Median TPOT (ms):                        21.64
P99 TPOT (ms):                           41.43
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.58
Median ITL (ms):                         19.94
P99 ITL (ms):                            147.82
==================================================

@mgoin The GSM8K score is quite low. We also tested with the transformers backend and found the scores are similar.

hf result

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3829|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.2570|±  |0.0120|
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@jeejeelee jeejeelee marked this pull request as draft August 15, 2025 01:28
@mergify mergify bot added the gpt-oss Related to GPT-OSS models label Aug 15, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new CUDA kernel for the swigluoai_and_mul activation function, used in GPT-OSS models. The changes include the kernel implementation, PyTorch bindings, and integration into the existing activation and fused MoE layers, along with corresponding tests.

My review has identified a few critical issues. The new CUDA kernel and its corresponding native PyTorch implementation for testing both incorrectly assume an interleaved memory layout, which is inconsistent with the rest of the codebase and will lead to incorrect behavior when integrated. Additionally, the PyTorch operator binding is missing default arguments, which will cause a runtime error in the fused MoE layer. I've provided specific suggestions to fix these issues.

Comment on lines +154 to +168
__global__ void swigluoai_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d, const float alpha, const float limit) {
  const int64_t token_idx = blockIdx.x;
  // TODO: Vectorize loads and stores.
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    // gate = x[..., ::2] (even indices)
    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
    // up = x[..., 1::2] (odd indices)
    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);

    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
  }
}

critical

The memory access pattern in swigluoai_and_mul_kernel assumes an interleaved memory layout for gate and up values ([g0, u0, g1, u1, ...]). However, other _and_mul activation kernels in this file and their usage in fused_moe expect a concatenated layout ([g0, g1, ..., u0, u1, ...]). This inconsistency will lead to incorrect data being read and wrong computation results when used in fused_moe. The kernel should be updated to use a concatenated memory layout, similar to act_and_mul_kernel.

__global__ void swigluoai_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2 * d]
    const int d, const float alpha, const float limit) {
  const int64_t token_idx = blockIdx.x;
  // TODO: Vectorize loads and stores.
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
  }
}
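
For illustration, here is a minimal PyTorch sketch (not code from this PR; the helper names are made up) of the two packing conventions being contrasted. It shows that the interleaved and concatenated readings of the same buffer generally disagree unless the data is repacked.

import torch

def swiglu_oai_interleaved(x, alpha=1.702, limit=7.0):
    # gate = x[..., ::2], up = x[..., 1::2] -- what the original kernel in this PR assumes
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (up + 1) * gate * torch.sigmoid(alpha * gate)

def swiglu_oai_concat(x, alpha=1.702, limit=7.0):
    # gate = x[..., :d], up = x[..., d:] -- layout used by the other _and_mul kernels
    # (and by the suggested kernel above)
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (up + 1) * gate * torch.sigmoid(alpha * gate)

x = torch.randn(4, 2 * 8)
# The two conventions read different elements as gate/up, so on the same buffer
# they generally produce different results:
print(torch.allclose(swiglu_oai_interleaved(x), swiglu_oai_concat(x)))  # typically False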
Comment on lines 133 to 135
  ops.def(
      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha, float limit) "
      "-> ()");

critical

The swigluoai_and_mul operator is called from fused_moe.py without the alpha and limit arguments, relying on default values. However, the operator definition here does not specify any default values. This will lead to a runtime error due to a mismatch in the number of arguments. Please add the default values to the operator definition to match the C++ function signature.

Suggested change
-  ops.def(
-      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha, float limit) "
-      "-> ()");
+  ops.def(
+      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float limit=7.0) "
+      "-> ()");
Comment on lines +250 to +258
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""

        gate, up = x[..., ::2], x[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        return gated_output

high

The forward_native implementation assumes an interleaved memory layout for gate and up tensors by using x[..., ::2] and x[..., 1::2]. This is inconsistent with other similar activation functions in vLLM which use a concatenated layout. To ensure correctness and consistency, especially for testing against the CUDA kernel (which should also be updated), this should be changed to split the tensor along the last dimension.

Suggested change
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
-        gate, up = x[..., ::2], x[..., 1::2]
-        gate = gate.clamp(min=None, max=self.limit)
-        up = up.clamp(min=-self.limit, max=self.limit)
-        glu = gate * torch.sigmoid(gate * self.alpha)
-        gated_output = (up + 1) * glu
-        return gated_output
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        gate, up = x[..., :d], x[..., d:]
+        gate = gate.clamp(min=None, max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        gated_output = (up + 1) * glu
+        return gated_output
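
With the concatenated layout, a parity check in the spirit of vLLM's activation tests could look like the sketch below. Note the class name SwigluOAIAndMul, its import path, and the forward_native/forward_cuda pairing are assumptions for illustration, not names taken verbatim from this diff.

import pytest
import torch

@pytest.mark.parametrize("num_tokens,d", [(7, 512), (83, 1024)])
def test_swigluoai_and_mul_parity(num_tokens, d):
    # Assumed class name and module path -- adjust to the actual PR code.
    from vllm.model_executor.layers.activation import SwigluOAIAndMul
    layer = SwigluOAIAndMul()
    x = torch.randn(num_tokens, 2 * d, dtype=torch.bfloat16, device="cuda")
    ref = layer.forward_native(x)   # PyTorch-native path
    out = layer.forward_cuda(x)     # custom CUDA kernel path
    torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)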
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
        torch.ops._C.silu_and_mul(intermediate_cache2,
                                  intermediate_cache1.view(-1, 2 * N))
-   elif activation == "swiglu_oai":
+   elif activation == "swigluoai":
@jeejeelee (Collaborator, Author)

@simon-mo Compared to the previous #22538, the main modification is here

Member

Can we replace the torch compile op here? Hopefully it is faster :)

@jeejeelee (Collaborator, Author)

Already replaced. I will spend time continuing to optimize this CUDA kernel later, so I haven't benchmarked it yet.
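
For context, the replacement boils down to routing the renamed "swigluoai" branch to the new custom op instead of the torch-compiled Python implementation. Below is a simplified, hypothetical helper; the function and tensor names are illustrative and are not the actual fused_moe code.

import torch

def apply_moe_activation(activation: str, out: torch.Tensor,
                         gate_up: torch.Tensor, N: int) -> None:
    # Illustrative helper mirroring the dispatch shown in the diff above.
    if activation == "silu":
        torch.ops._C.silu_and_mul(out, gate_up.view(-1, 2 * N))
    elif activation == "swigluoai":
        # New custom CUDA kernel from this PR; alpha/limit fall back to the
        # schema defaults suggested in review.
        torch.ops._C.swigluoai_and_mul(out, gate_up.view(-1, 2 * N))
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")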

@jeejeelee jeejeelee marked this pull request as ready for review August 15, 2025 02:13
@jeejeelee (Collaborator, Author)

@mgoin Would you mind taking another look at this PR? This is a reopening of #22538, which fixes the activation function naming error that appeared after #22428 was merged.


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee force-pushed the support-oss-activation branch from 6294d3b to da30292 Compare August 16, 2025 03:56
@jeejeelee jeejeelee requested a review from mgoin August 16, 2025 03:57
@mgoin mgoin added kernel ready ONLY add when PR is ready to merge/full CI is needed labels Aug 16, 2025
@mgoin mgoin enabled auto-merge (squash) August 16, 2025 16:29
@mgoin mgoin merged commit 4d4061b into vllm-project:main Aug 17, 2025
72 checks passed
@jeejeelee jeejeelee deleted the support-oss-activation branch August 17, 2025 10:17
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
cyang49 pushed a commit to cyang49/vllm that referenced this pull request Aug 20, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Labels

gpt-oss (Related to GPT-OSS models), kernel, ready (ONLY add when PR is ready to merge/full CI is needed)

2 participants