 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #ifdef PADDLE_WITH_XPU_XRE5
+#include "paddle/phi/kernels/slice_kernel.h"
 #include "paddle/phi/kernels/xpu/flash_attn_utils.h"
 #include "xfa/flash_api.h"
 #endif
 namespace phi {
 #ifdef PADDLE_WITH_XPU_XRE5
 template <typename T, typename Context>
-void FlashAttnGradKernelBase(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const api::VectorParam<int>& lod_seqlen_q,
-                             const api::VectorParam<int>& lod_seqlen_k,
-                             const DenseTensor& out,
-                             const DenseTensor& softmax_lse,
-                             const DenseTensor& seed_offset,
-                             const paddle::optional<DenseTensor>& attn_mask,
-                             const DenseTensor& dout,
-                             const int batch_size,
-                             const Scalar& max_seqlen_q_,
-                             const Scalar& max_seqlen_k_,
-                             const int num_heads,
-                             const int num_heads_k,
-                             const int head_size,
-                             const int head_size_v,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             DenseTensor* dq,
-                             DenseTensor* dk,
-                             DenseTensor* dv) {
+void FlashAttnGradKernelBase(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const api::VectorParam<int>& lod_seqlen_q,
+    const api::VectorParam<int>& lod_seqlen_k,
+    const DenseTensor& out,
+    const DenseTensor& softmax_lse,
+    const DenseTensor& seed_offset,
+    const paddle::optional<DenseTensor>& attn_mask,
+    const paddle::optional<DenseTensor>& startend_row_indices,
+    const DenseTensor& dout,
+    const int batch_size,
+    const Scalar& max_seqlen_q_,
+    const Scalar& max_seqlen_k_,
+    const int num_heads,
+    const int num_heads_k,
+    const int head_size,
+    const int head_size_v,
+    float scale,
+    float dropout,
+    bool causal,
+    DenseTensor* dq,
+    DenseTensor* dk,
+    DenseTensor* dv) {
   xpu::ctx_guard RAII_GUARD(ctx.x_context());
 
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -62,7 +65,52 @@ void FlashAttnGradKernelBase(const Context& ctx,
 
   const float* bias_data = nullptr;
   int64_t fa_layout = AttnQKVLayout_t::ATTN_BLHD;
-  if (attn_mask.get_ptr() != nullptr) {
+  DenseTensor downstart_row_indices, upend_row_indices, downend_row_indices,
+      upstart_row_indices;
+  void *downstart_row_indices_data = nullptr, *upend_row_indices_data = nullptr,
+       *downend_row_indices_data = nullptr, *upstart_row_indices_data = nullptr;
+  bool is_flashmask = startend_row_indices.get_ptr() != nullptr;
+  XPUStream flashmask_stream;
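+  // FlashMask path: when startend_row_indices is provided it carries the
+  // sparse-mask row bounds (shape [batch_size, num_heads, seq_len, mask_bounds],
+  // checked below) and is split along its last axis into the start/end row
+  // index tensors consumed by the XPU flash-attention kernel.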
+  if (is_flashmask) {
+    xpu_stream_create(&flashmask_stream);
+    PADDLE_ENFORCE_EQ(
+        startend_row_indices->dims().size(),
+        4,
+        common::errors::InvalidArgument(
+            "flashmask_attention receives startend_row_indices with dims "
+            "[batch_size, num_heads, seq_len, mask_bounds]"));
+    PADDLE_ENFORCE_EQ(startend_row_indices->dims()[3] == 1 ||
+                          startend_row_indices->dims()[3] == 2 ||
+                          startend_row_indices->dims()[3] == 4,
+                      true,
+                      common::errors::InvalidArgument(
+                          "flashmask_attention startend_row_indices "
+                          "mask_bounds must be in [1, 2, 4]"));
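+    // Layout of the last axis, matching the slices below: column 0 is the
+    // down-start row index; with 2 bounds column 1 is the up-end index
+    // (non-causal) or the down-end index (causal); with 4 bounds the columns
+    // are [down-start, down-end, up-start, up-end].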
+    downstart_row_indices =
+        phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {0}, {1});
+    downstart_row_indices_data = downstart_row_indices.data();
+    if (startend_row_indices->dims()[3] == 2) {
+      if (!causal) {
+        upend_row_indices =
+            phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
+        upend_row_indices_data = upend_row_indices.data();
+      } else {
+        downend_row_indices =
+            phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
+        downend_row_indices_data = downend_row_indices.data();
+      }
+    } else if (startend_row_indices->dims()[3] == 4) {
+      upend_row_indices =
+          phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {3}, {4});
+      upend_row_indices_data = upend_row_indices.data();
+      downend_row_indices =
+          phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
+      downend_row_indices_data = downend_row_indices.data();
+      upstart_row_indices =
+          phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {2}, {3});
+      upstart_row_indices_data = upstart_row_indices.data();
+    }
+  } else if (attn_mask.get_ptr() != nullptr) {
     const auto& mask_dims = attn_mask->dims();
     if (mask_dims.size() == 3 || (mask_dims[1] == 1 && mask_dims.size() == 4)) {
       fa_layout |= AttnQKVLayout_t::BIAS_BLL;
@@ -175,15 +223,17 @@ void FlashAttnGradKernelBase(const Context& ctx,
       -1,  // window_size_left
       -1,  // window_size_right
       head_size_v,  // v_head_dim
-      nullptr,  // downstart_row_indices_data
-      nullptr,  // downend_row_indices_data
-      nullptr,  // upstart_row_indices_data
-      nullptr,  // upend_row_indices_data
-      0,  // flash_mask_head_num
-      nullptr,  // flashmask_maxmin
-      nullptr  // side_stream
-  );
+      (const int*)downstart_row_indices_data,
+      (const int*)downend_row_indices_data,
+      (const int*)upstart_row_indices_data,
+      (const int*)upend_row_indices_data,
+      is_flashmask ? startend_row_indices->dims()[1] : 0,  // flash_mask_head_num
+      nullptr,  // flashmask_maxmin
+      is_flashmask ? flashmask_stream : nullptr);  // side_stream
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
+  if (is_flashmask && flashmask_stream != nullptr) {
+    xpu_stream_destroy(flashmask_stream);
+  }
 }
 #endif
 
@@ -236,6 +286,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                              softmax_lse,
                              seed_offset,
                              attn_mask,
+                             paddle::none,
                              dout,
                              batch_size,
                              max_seqlen_q,
@@ -316,6 +367,7 @@ void FlashAttnGradKernel(const Context& ctx,
                              softmax_lse,
                              seed_offset,
                              attn_mask,
+                             paddle::none,
                              dout,
                              batch_size,
                              seqlen_q,
@@ -336,6 +388,85 @@ void FlashAttnGradKernel(const Context& ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void FlashMaskGradKernel(const Context& ctx,
+                         const DenseTensor& q,
+                         const DenseTensor& k,
+                         const DenseTensor& v,
+                         const DenseTensor& startend_row_indices,
+                         const DenseTensor& out,
+                         const DenseTensor& softmax_lse,
+                         const DenseTensor& seed_offset,
+                         const DenseTensor& dout,
+                         float dropout,
+                         bool causal,
+                         DenseTensor* dq,
+                         DenseTensor* dk,
+                         DenseTensor* dv) {
+#ifdef PADDLE_WITH_XPU_XRE5
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+
+  // q, k, v [batch_size, seq_len, num_heads, head_dim]
+  const auto& dims = q.dims();
+
+  const int64_t batch_size = dims[0];
+  const int64_t seqlen_q = dims[1];
+  const int64_t num_heads = dims[2];
+  const int64_t head_size_og = dout.dims()[3];
+  const int64_t head_size = dims[3];
+  const int64_t head_size_v = v.dims()[3];
+  const int64_t seqlen_k = k.dims()[1];
+  const int64_t num_heads_k = k.dims()[2];
+
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size_v,
+      common::errors::InvalidArgument(
+          "flash_attn_bwd requires head_size_og == head_size_v"));
+
+  // lod info
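+  // With fixed-length sequences the cumulative offsets are simply
+  // {0, S, 2*S, ..., B*S} for seq_len S (per tensor) and batch_size B.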
+  std::vector<int> qlod_vec = {0};
+  std::vector<int> kvlod_vec = {0};
+  for (int batch_idx = 1; batch_idx <= batch_size; ++batch_idx) {
+    qlod_vec.push_back(seqlen_q * batch_idx);
+    kvlod_vec.push_back(seqlen_k * batch_idx);
+  }
+  api::VectorParam<int> qlod{
+      qlod_vec.data(), static_cast<int64_t>(qlod_vec.size()), nullptr};
+  api::VectorParam<int> kvlod{
+      kvlod_vec.data(), static_cast<int64_t>(kvlod_vec.size()), nullptr};
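+  // Reuse the shared varlen backward path; the flash-mask row indices are
+  // passed via startend_row_indices and no dense attn_mask is supplied.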
+  FlashAttnGradKernelBase<T>(ctx,
+                             q,
+                             k,
+                             v,
+                             qlod,
+                             kvlod,
+                             out,
+                             softmax_lse,
+                             seed_offset,
+                             paddle::none,
+                             startend_row_indices,
+                             dout,
+                             batch_size,
+                             seqlen_q,
+                             seqlen_k,
+                             num_heads,
+                             num_heads_k,
+                             head_size,
+                             head_size_v,
+                             0.0,
+                             dropout,
+                             causal,
+                             dq,
+                             dk,
+                             dv);
+#else
+  PADDLE_THROW(common::errors::Unimplemented(
+      "re-compile using -DWITH_XPU_XRE5=ON to use FlashMaskGradKernel"));
+#endif
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
@@ -359,3 +490,13 @@ PD_REGISTER_KERNEL(flash_attn_grad,
                    phi::dtype::float16) {
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
 }
+
+PD_REGISTER_KERNEL(flashmask_attention_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FlashMaskGradKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(6).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}