Skip to content

Commit 6fce752

Browse files
committed
[bugfix] fix NaN lse of _correct_attn_cp_out_kernel
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
1 parent b2fad26 commit 6fce752

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

vllm/attention/ops/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
5353
lse = tl.load(lses_ptr + lse_offsets)
5454
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
5555
lse_max = tl.max(lse, axis=0)
56+
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
5657
lse -= lse_max
5758
lse_exp = tl.exp(lse)
5859
lse_acc = tl.sum(lse_exp, axis=0)

0 commit comments

Comments
 (0)