
Commit 76e4dcf

[Misc] Remove unused attention prefix prefill ops functions (vllm-project#26971)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
1 parent d5edcb8 commit 76e4dcf

File tree

2 files changed: 0 additions, 213 deletions


vllm/attention/ops/prefix_prefill.py

Lines changed: 0 additions & 210 deletions
@@ -335,216 +335,6 @@ def _fwd_kernel(
     return
 
 
-@triton.jit
-def _fwd_kernel_flash_attn_v2(
-    Q,
-    K,
-    V,
-    K_cache,
-    V_cache,
-    B_Loc,
-    sm_scale,
-    B_Start_Loc,
-    B_Seqlen,
-    B_Ctxlen,
-    block_size,
-    x,
-    Out,
-    stride_b_loc_b,
-    stride_b_loc_s,
-    stride_qbs,
-    stride_qh,
-    stride_qd,
-    stride_kbs,
-    stride_kh,
-    stride_kd,
-    stride_vbs,
-    stride_vh,
-    stride_vd,
-    stride_obs,
-    stride_oh,
-    stride_od,
-    stride_k_cache_bs,
-    stride_k_cache_h,
-    stride_k_cache_d,
-    stride_k_cache_bl,
-    stride_k_cache_x,
-    stride_v_cache_bs,
-    stride_v_cache_h,
-    stride_v_cache_d,
-    stride_v_cache_bl,
-    num_queries_per_kv: int,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-):
-    cur_batch = tl.program_id(0)
-    cur_head = tl.program_id(1)
-    start_m = tl.program_id(2)
-
-    cur_kv_head = cur_head // num_queries_per_kv
-
-    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
-
-    block_start_loc = BLOCK_M * start_m
-
-    # initialize offsets
-    offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
-    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_q = (
-        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
-        + cur_head * stride_qh
-        + offs_d[None, :] * stride_qd
-    )
-
-    q = tl.load(
-        Q + off_q,
-        mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-        other=0.0,
-    )
-
-    # # initialize pointer to m and l
-    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-
-    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        # -- compute qk ----
-        bn = tl.load(
-            B_Loc
-            + cur_batch * stride_b_loc_b
-            + ((start_n + offs_n) // block_size) * stride_b_loc_s,
-            mask=(start_n + offs_n) < cur_batch_ctx_len,
-            other=0,
-        ).to(tl.int64)
-        off_k = (
-            bn[None, :] * stride_k_cache_bs
-            + cur_kv_head * stride_k_cache_h
-            + (offs_d[:, None] // x) * stride_k_cache_d
-            + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
-            + (offs_d[:, None] % x) * stride_k_cache_x
-        )
-        off_v = (
-            bn[:, None] * stride_v_cache_bs
-            + cur_kv_head * stride_v_cache_h
-            + offs_d[None, :] * stride_v_cache_d
-            + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
-        )
-        k = tl.load(
-            K_cache + off_k,
-            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
-            other=0.0,
-        )
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
-        qk = tl.where(
-            (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
-        )
-        qk *= sm_scale
-
-        # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        m_i_new = tl.maximum(m_i, m_ij)
-        p = tl.math.exp(qk - m_i_new[:, None])
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-
-        alpha = tl.math.exp(m_i - m_i_new)
-        l_i_new = alpha * l_i + l_ij
-        # -- update output accumulator --
-        # scale p
-        # scale acc
-        acc_scale = alpha
-        # acc_scale = l_i / l_i_new * alpha
-        acc = acc * acc_scale[:, None]
-        # update acc
-        v = tl.load(
-            V_cache + off_v,
-            mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
-            other=0.0,
-        )
-
-        p = p.to(v.dtype)
-        acc += tl.dot(p, v)
-        # update m_i and l_i
-        l_i = l_i_new
-        m_i = m_i_new
-
-    off_k = (
-        offs_n[None, :] * stride_kbs
-        + cur_kv_head * stride_kh
-        + offs_d[:, None] * stride_kd
-    )
-    off_v = (
-        offs_n[:, None] * stride_vbs
-        + cur_kv_head * stride_vh
-        + offs_d[None, :] * stride_vd
-    )
-    k_ptrs = K + off_k
-    v_ptrs = V + off_v
-
-    block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
-
-    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        # -- compute qk ----
-        k = tl.load(
-            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
-            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0,
-        )
-
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
-        qk *= sm_scale
-        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
-
-        # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        m_i_new = tl.maximum(m_i, m_ij)
-        p = tl.math.exp(qk - m_i_new[:, None])
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-
-        alpha = tl.math.exp(m_i - m_i_new)
-        l_i_new = alpha * l_i + l_ij
-        # -- update output accumulator --
-        # scale p
-        # scale acc
-        acc_scale = alpha
-        # acc_scale = l_i / l_i_new * alpha
-        acc = acc * acc_scale[:, None]
-        # update acc
-        v = tl.load(
-            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
-            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0,
-        )
-
-        p = p.to(v.dtype)
-        acc += tl.dot(p, v)
-        # update m_i and l_i
-        l_i = l_i_new
-        m_i = m_i_new
-
-    # acc /= l_i[:, None]
-    # initialize pointers to output
-    off_o = (
-        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
-        + cur_head * stride_oh
-        + offs_d[None, :] * stride_od
-    )
-    out_ptrs = Out + off_o
-    tl.store(
-        out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len
-    )
-    return
-
-
 @triton.jit
 def _fwd_kernel_alibi(
     Q,
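
For context on what was removed: `_fwd_kernel_flash_attn_v2` evaluates attention with the FlashAttention-style online-softmax recurrence, streaming BLOCK_N-sized key/value tiles (first from the paged KV cache via the `B_Loc` block table, then from the newly prefilled tokens) while carrying a running row max `m_i`, a running denominator `l_i`, and an unnormalized accumulator `acc`. The sketch below is a minimal NumPy reference for that recurrence only; it is not vLLM code, the function name and tile layout are invented for illustration, and it omits the paged-cache addressing, masking, and grouped-query head mapping that the kernel handles.

import numpy as np

def online_softmax_attention(q, k_tiles, v_tiles, sm_scale):
    # q: (M, D) query tile; k_tiles / v_tiles: sequences of (N, D) key/value
    # tiles visited one chunk at a time, mirroring the kernel's BLOCK_N loops.
    M, D = q.shape
    m_i = np.full(M, -np.inf)   # running row-wise max of the scores
    l_i = np.zeros(M)           # running softmax denominator
    acc = np.zeros((M, D))      # unnormalized weighted sum of values

    for k, v in zip(k_tiles, v_tiles):
        qk = (q @ k.T) * sm_scale          # scores for this tile, shape (M, N)
        m_ij = qk.max(axis=1)
        m_i_new = np.maximum(m_i, m_ij)    # updated running max
        p = np.exp(qk - m_i_new[:, None])  # shifted exponentials
        alpha = np.exp(m_i - m_i_new)      # rescales previously accumulated sums
        l_i = alpha * l_i + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v
        m_i = m_i_new

    # The deleted kernel stores `acc` with the final `acc /= l_i[:, None]`
    # left commented out; this reference applies the normalization here.
    return acc / l_i[:, None]

The `alpha` rescaling corresponds to the `acc_scale = alpha` / `acc = acc * acc_scale[:, None]` step visible in both loops of the removed kernel.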

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 0 additions & 3 deletions
@@ -98,9 +98,6 @@ class GPTQMarlinState(Enum):
 
 
 class CompressedTensorsMoEMethod(FusedMoEMethodBase):
-    def __init_(self, moe: FusedMoEConfig):
-        super().__init__(moe)
-
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
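
The second hunk drops a constructor whose name is spelled `__init_` (a single trailing underscore) in the deleted lines above. Assuming that spelling is faithful to the file, Python never invokes such a method during instantiation, so `FusedMoEMethodBase.__init__` already ran and the override was dead code; removing it is behavior-preserving. A standalone illustration with invented names (not vLLM classes):

class Base:
    def __init__(self, cfg):
        self.cfg = cfg

class Child(Base):
    def __init_(self, cfg):  # misspelled: one trailing underscore, not __init__
        raise AssertionError("never reached via normal construction")

# Object construction only calls __init__, so Base.__init__ runs and the
# misspelled method is unreachable -- deleting it changes nothing.
c = Child("some-config")
print(c.cfg)  # prints: some-config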
