@@ -335,216 +335,6 @@ def _fwd_kernel(
335335 return
336336
337337
@triton.jit
def _fwd_kernel_flash_attn_v2(
    Q,            # new-token queries, indexed via stride_qbs/qh/qd
    K,            # new-token keys (not yet in cache)
    V,            # new-token values (not yet in cache)
    K_cache,      # paged key cache; last dim split by factor `x` (see off_k)
    V_cache,      # paged value cache
    B_Loc,        # per-batch block table mapping logical ctx blocks -> cache blocks
    sm_scale,     # softmax scale applied to q @ k
    B_Start_Loc,  # per-batch start offset of its queries in the packed Q/Out
    B_Seqlen,     # per-batch total sequence length (context + new tokens)
    B_Ctxlen,     # per-batch cached-context length
    block_size,   # tokens per cache block
    x,            # packing factor of K_cache's head-dim split
    Out,          # output buffer, same packed layout as Q
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,   # query-rows tile size
    BLOCK_DMODEL: tl.constexpr,  # head dimension
    BLOCK_N: tl.constexpr,   # key/value-cols tile size
):
    """FlashAttention-v2-style prefix-prefill kernel.

    Each program instance computes attention output for one tile of BLOCK_M
    new-token queries of one (batch, head) pair, attending first to the
    paged KV cache (the already-computed context) and then causally to the
    new tokens themselves, using the online-softmax running (m_i, l_i)
    rescaling scheme.
    """
    # Grid: (batch, head, query-tile index).
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Grouped-query attention: several query heads share one KV head.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :] * stride_qd
    )

    # Load this tile's queries once; rows past the new-token count are masked.
    q = tl.load(
        Q + off_q,
        mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
        other=0.0,
    )

    # # initialize pointer to m and l
    # m_i: running row-max of logits; l_i: running softmax denominator;
    # acc: unnormalized weighted sum of values.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # ---- Pass 1: attend to the cached context via the block table ----
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Translate logical context positions to physical cache block ids.
        bn = tl.load(
            B_Loc
            + cur_batch * stride_b_loc_b
            + ((start_n + offs_n) // block_size) * stride_b_loc_s,
            mask=(start_n + offs_n) < cur_batch_ctx_len,
            other=0,
        ).to(tl.int64)
        # K_cache stores the head dim split into (d // x, x) chunks.
        off_k = (
            bn[None, :] * stride_k_cache_bs
            + cur_kv_head * stride_k_cache_h
            + (offs_d[:, None] // x) * stride_k_cache_d
            + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
            + (offs_d[:, None] % x) * stride_k_cache_x
        )
        off_v = (
            bn[:, None] * stride_v_cache_bs
            + cur_kv_head * stride_v_cache_h
            + offs_d[None, :] * stride_v_cache_d
            + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
        )
        k = tl.load(
            K_cache + off_k,
            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
            other=0.0,
        )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        # Out-of-range context columns get -inf so they vanish in softmax.
        qk = tl.where(
            (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
        )
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        # FA-v2 rescale: prior accumulator shrinks by exp(old_max - new_max).
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            V_cache + off_v,
            mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # ---- Pass 2: attend causally to the new (uncached) tokens ----
    off_k = (
        offs_n[None, :] * stride_kbs
        + cur_kv_head * stride_kh
        + offs_d[:, None] * stride_kd
    )
    off_v = (
        offs_n[:, None] * stride_vbs
        + cur_kv_head * stride_vh
        + offs_d[None, :] * stride_vd
    )
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0/1 flag: does this query tile contain any valid new-token rows?
    # Multiplying the loop bound by it skips the pass entirely when not.
    block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # Only iterate key tiles up to (and including) this query tile (causality).
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0,
        )

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Causal mask within the new tokens: query row i sees key col j iff i >= j.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # acc /= l_i[:, None]
    # NOTE(review): the final softmax normalization above is commented out, so
    # `acc` is stored UNNORMALIZED (missing the divide by l_i). Confirm the
    # caller either normalizes afterwards or never uses this kernel's output
    # directly — otherwise this looks like a correctness bug.
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :] * stride_od
    )
    out_ptrs = Out + off_o
    tl.store(
        out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len
    )
    return
547-
548338@triton .jit
549339def _fwd_kernel_alibi (
550340 Q ,
0 commit comments