-
- Notifications
You must be signed in to change notification settings - Fork 11.1k
[Kernel] Add cuda kernel for gpt_oss activation #22951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
| 👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new CUDA kernel for the swigluoai_and_mul activation function, used in GPT-OSS models. The changes include the kernel implementation, PyTorch bindings, and integration into the existing activation and fused MoE layers, along with corresponding tests.
My review has identified a few critical issues. The new CUDA kernel and its corresponding native PyTorch implementation for testing both incorrectly assume an interleaved memory layout, which is inconsistent with the rest of the codebase and will lead to incorrect behavior when integrated. Additionally, the PyTorch operator binding is missing default arguments, which will cause a runtime error in the fused MoE layer. I've provided specific suggestions to fix these issues.
| __global__ void swigluoai_and_mul_kernel( | ||
| scalar_t* __restrict__ out, // [..., d] | ||
| const scalar_t* __restrict__ input, // [..., 2, d] | ||
| const int d, const float alpha, const float limit) { | ||
| const int64_t token_idx = blockIdx.x; | ||
| // TODO: Vectorize loads and stores. | ||
| for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { | ||
| // gate = x[..., ::2] (even indices) | ||
| const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]); | ||
| // up = x[..., 1::2] (odd indices) | ||
| const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]); | ||
| | ||
| out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The memory access pattern in swigluoai_and_mul_kernel assumes an interleaved memory layout for gate and up values ([g0, u0, g1, u1, ...]). However, other _and_mul activation kernels in this file and their usage in fused_moe expect a concatenated layout ([g0, g1, ..., u0, u1, ...]). This inconsistency will lead to incorrect data being read and wrong computation results when used in fused_moe. The kernel should be updated to use a concatenated memory layout, similar to act_and_mul_kernel.
__global__ void swigluoai_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2 * d] const int d, const float alpha, const float limit) { const int64_t token_idx = blockIdx.x; // TODO: Vectorize loads and stores. for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit); } } | ops.def( | ||
| "swigluoai_and_mul(Tensor! out, Tensor input, float alpha, float limit) " | ||
| "-> ()"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The swigluoai_and_mul operator is called from fused_moe.py without the alpha and limit arguments, relying on default values. However, the operator definition here does not specify any default values. This will lead to a runtime error due to a mismatch in the number of arguments. Please add the default values to the operator definition to match the C++ function signature.
| ops.def( | |
| "swigluoai_and_mul(Tensor! out, Tensor input, float alpha, float limit) " | |
| "-> ()"); | |
| ops.def( | |
| "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float limit=7.0) " | |
| "-> ()"); |
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | ||
| """PyTorch-native implementation equivalent to forward().""" | ||
| | ||
| gate, up = x[..., ::2], x[..., 1::2] | ||
| gate = gate.clamp(min=None, max=self.limit) | ||
| up = up.clamp(min=-self.limit, max=self.limit) | ||
| glu = gate * torch.sigmoid(gate * self.alpha) | ||
| gated_output = (up + 1) * glu | ||
| return gated_output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The forward_native implementation assumes an interleaved memory layout for gate and up tensors by using x[..., ::2] and x[..., 1::2]. This is inconsistent with other similar activation functions in vLLM which use a concatenated layout. To ensure correctness and consistency, especially for testing against the CUDA kernel (which should also be updated), this should be changed to split the tensor along the last dimension.
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | |
| """PyTorch-native implementation equivalent to forward().""" | |
| gate, up = x[..., ::2], x[..., 1::2] | |
| gate = gate.clamp(min=None, max=self.limit) | |
| up = up.clamp(min=-self.limit, max=self.limit) | |
| glu = gate * torch.sigmoid(gate * self.alpha) | |
| gated_output = (up + 1) * glu | |
| return gated_output | |
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | |
| """PyTorch-native implementation equivalent to forward().""" | |
| d = x.shape[-1] // 2 | |
| gate, up = x[..., :d], x[..., d:] | |
| gate = gate.clamp(min=None, max=self.limit) | |
| up = up.clamp(min=-self.limit, max=self.limit) | |
| glu = gate * torch.sigmoid(gate * self.alpha) | |
| gated_output = (up + 1) * glu | |
| return gated_output |
| torch.ops._C.silu_and_mul(intermediate_cache2, | ||
| intermediate_cache1.view(-1, 2 * N)) | ||
| elif activation == "swiglu_oai": | ||
| elif activation == "swigluoai": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we replace the torch compile op here? Hopefully it is faster :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already replaced. I will spend time continuing to optimize this cuda kernel later, so I haven't conducted testing now
| torch.ops._C.silu_and_mul(intermediate_cache2, | ||
| intermediate_cache1.view(-1, 2 * N)) | ||
| elif activation == "swiglu_oai": | ||
| elif activation == "swigluoai": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we replace the torch compile op here? Hopefully it is faster :)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
6294d3b to da30292 Compare Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.Purpose
Reopen of #22538
Test Plan
lm eval
VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args "pretrained=unsloth/gpt-oss-20b-BF16,max_model_len=32768" --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size autobenchmrk
Test Result(A800)
This PR
The main branch
@mgoin The value of GSM8K is too low. We also tested using the transformers backend and found that the scores were similar.