Skip to content

Commit c4768dc

Browse files
authored
[Kernel] Fix fused_gdn_gating (vllm-project#28343)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent a65a934 commit c4768dc

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

vllm/model_executor/models/qwen3_next.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,8 +1367,10 @@ def fused_gdn_gating_kernel(
13671367
blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
13681368
tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
13691369
# compute beta_output = sigmoid(b)
1370-
blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
1371-
tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
1370+
blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
1371+
tl.store(
1372+
beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
1373+
)
13721374

13731375

13741376
def fused_gdn_gating(
@@ -1389,7 +1391,7 @@ def fused_gdn_gating(
13891391
seq_len = 1
13901392
grid = (batch, seq_len, triton.cdiv(num_heads, 8))
13911393
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
1392-
beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
1394+
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
13931395
fused_gdn_gating_kernel[grid](
13941396
g,
13951397
beta_output,

0 commit comments

Comments
 (0)