There was an error while loading. Please reload this page.
1 parent d7b5ae7 commit 6ddf209Copy full SHA for 6ddf209
vllm/attention/ops/common.py
@@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
53
lse = tl.load(lses_ptr + lse_offsets)
54
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
55
lse_max = tl.max(lse, axis=0)
56
+ lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
57
lse -= lse_max
58
lse_exp = tl.exp(lse)
59
lse_acc = tl.sum(lse_exp, axis=0)
0 commit comments