-
- Notifications
You must be signed in to change notification settings - Fork 11.1k
[Kernel] Add cuda kernel for gpt_oss activation #22951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9abad5a 14462e1 c5ba5ee d2fcd71 7d7f7a3 27bf18a 99c9f56 8471138 3eee9f0 c1fb96a 4e15e4c 2404967 0a07d9d d8f3186 da30292 22f58d0 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | @@ -239,6 +239,35 @@ def extra_repr(self) -> str: | |||||||||||||||||||||||||||||||||||||
| return f'approximate={repr(self.approximate)}' | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| @CustomOp.register("swigluoai_and_mul") | ||||||||||||||||||||||||||||||||||||||
| class SwigluOAIAndMul(CustomOp): | ||||||||||||||||||||||||||||||||||||||
| # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110 | ||||||||||||||||||||||||||||||||||||||
| def __init__(self, alpha: float = 1.702, limit: float = 7.0): | ||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||
| self.alpha = alpha | ||||||||||||||||||||||||||||||||||||||
| self.limit = limit | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| def forward_native(self, x: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||
| """PyTorch-native implementation equivalent to forward().""" | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| gate, up = x[..., ::2], x[..., 1::2] | ||||||||||||||||||||||||||||||||||||||
| gate = gate.clamp(min=None, max=self.limit) | ||||||||||||||||||||||||||||||||||||||
| up = up.clamp(min=-self.limit, max=self.limit) | ||||||||||||||||||||||||||||||||||||||
| glu = gate * torch.sigmoid(gate * self.alpha) | ||||||||||||||||||||||||||||||||||||||
| gated_output = (up + 1) * glu | ||||||||||||||||||||||||||||||||||||||
| return gated_output | ||||||||||||||||||||||||||||||||||||||
| Comment on lines +250 to +258 Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Suggested change
| ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||
| d = x.shape[-1] // 2 | ||||||||||||||||||||||||||||||||||||||
| output_shape = (x.shape[:-1] + (d, )) | ||||||||||||||||||||||||||||||||||||||
| out = torch.empty(output_shape, dtype=x.dtype, device=x.device) | ||||||||||||||||||||||||||||||||||||||
| torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit) | ||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| def extra_repr(self) -> str: | ||||||||||||||||||||||||||||||||||||||
| return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}" | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| @CustomOp.register("gelu_new") | ||||||||||||||||||||||||||||||||||||||
| class NewGELU(CustomOp): | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | @@ -330,6 +359,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: | |||||||||||||||||||||||||||||||||||||
| return torch.square(F.relu(x)) | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||
| #TODO : implement cuda kenrels | ||||||||||||||||||||||||||||||||||||||
| return self.forward_native(x) | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | @@ -406,9 +436,14 @@ def get_act_fn(act_fn_name: str) -> nn.Module: | |||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| _ACTIVATION_AND_MUL_REGISTRY = LazyDict({ | ||||||||||||||||||||||||||||||||||||||
| "gelu": lambda: GeluAndMul(), | ||||||||||||||||||||||||||||||||||||||
| "silu": lambda: SiluAndMul(), | ||||||||||||||||||||||||||||||||||||||
| "geglu": lambda: GeluAndMul(), | ||||||||||||||||||||||||||||||||||||||
| "gelu": | ||||||||||||||||||||||||||||||||||||||
| lambda: GeluAndMul(), | ||||||||||||||||||||||||||||||||||||||
| "silu": | ||||||||||||||||||||||||||||||||||||||
| lambda: SiluAndMul(), | ||||||||||||||||||||||||||||||||||||||
| "geglu": | ||||||||||||||||||||||||||||||||||||||
| lambda: GeluAndMul(), | ||||||||||||||||||||||||||||||||||||||
| "swigluoai": | ||||||||||||||||||||||||||||||||||||||
| lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), | ||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| | ||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -169,25 +169,13 @@ def fused_marlin_moe(hidden_states: torch.Tensor, | |
| if activation == "silu": | ||
| torch.ops._C.silu_and_mul(intermediate_cache2, | ||
| intermediate_cache1.view(-1, 2 * N)) | ||
| elif activation == "swiglu_oai": | ||
| # NOTE: in gpt-oss, the gate_proj and up_proj is interleaved | ||
| # - interleaved: gate, up = gate_up[..., ::2], gate_up[..., 1::2] | ||
| # - origin: gate, up = gate_up[..., :N], gate_up[..., N:] | ||
| | ||
| @torch.compile(dynamic=True) | ||
| def swiglu_oai(gate_up): | ||
| alpha = 1.702 | ||
| limit = 7.0 | ||
| gate, up = gate_up[..., ::2], gate_up[..., 1::2] | ||
| gate = gate.clamp(min=None, max=limit) | ||
| up = up.clamp(min=-limit, max=limit) | ||
| glu = gate * torch.sigmoid(gate * alpha) | ||
| return (up + 1) * glu | ||
| | ||
| intermediate_cache2 = swiglu_oai(intermediate_cache1) | ||
| elif activation == "swigluoai": | ||
| Collaborator Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Member There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we replace the torch compile op here? Hopefully it is faster :) Collaborator Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Already replaced. I will spend time continuing to optimize this cuda kernel later, so I haven't conducted testing now | ||
| # alpha = 1.702, limit = 7.0 | ||
| torch.ops._C.swigluoai_and_mul(intermediate_cache2, | ||
| intermediate_cache1.view(-1, 2 * N)) | ||
| else: | ||
| raise ValueError(f"Unsupported activation: {activation}. " | ||
| "Only silu and swiglu_oai activations are supported.") | ||
| "Only silu and swigluoai activations are supported.") | ||
| | ||
| if expert_map is not None: | ||
| intermediate_cache3.zero_() | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The memory access pattern in
swigluoai_and_mul_kernelassumes an interleaved memory layout forgateandupvalues ([g0, u0, g1, u1, ...]). However, other_and_mulactivation kernels in this file and their usage infused_moeexpect a concatenated layout ([g0, g1, ..., u0, u1, ...]). This inconsistency will lead to incorrect data being read and wrong computation results when used infused_moe. The kernel should be updated to use a concatenated memory layout, similar toact_and_mul_kernel.