
Commit 11a95ee

[XPU] support flashmask_attention (PaddlePaddle#71573)
1 parent 5235884 commit 11a95ee

4 files changed: +738 additions, -126 deletions


paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 4 additions & 0 deletions
@@ -615,6 +615,10 @@ XPUOpMap& get_kl3_ops() {
      XPUKernelSet({phi::DataType::BFLOAT16,
                    phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16})},
+    {"flashmask_attention",
+     XPUKernelSet({phi::DataType::BFLOAT16, phi::DataType::FLOAT16})},
+    {"flashmask_attention_grad",
+     XPUKernelSet({phi::DataType::BFLOAT16, phi::DataType::FLOAT16})},
     {"flash_attn_unpadded",
      XPUKernelSet({phi::DataType::BFLOAT16,
                    phi::DataType::FLOAT32,
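
Note on the hunk above: each entry in the KL3 op list maps an op name to the set of XPU dtypes its kernel accepts, and flashmask_attention / flashmask_attention_grad are registered for bfloat16 and float16 only (no float32, unlike flash_attn_unpadded). The sketch below is not part of the commit; it uses simplified standalone types (DataType, kl3_ops) as assumed stand-ins to illustrate how such a name-to-kernel-set lookup behaves.

// Standalone illustration only: a simplified stand-in for the
// op-name -> supported-dtype-set lookup used by the XPU op list.
#include <iostream>
#include <map>
#include <set>
#include <string>

enum class DataType { BFLOAT16, FLOAT16, FLOAT32 };

int main() {
  std::map<std::string, std::set<DataType>> kl3_ops = {
      {"flashmask_attention", {DataType::BFLOAT16, DataType::FLOAT16}},
      {"flashmask_attention_grad", {DataType::BFLOAT16, DataType::FLOAT16}},
  };
  auto it = kl3_ops.find("flashmask_attention");
  bool fp32_ok =
      it != kl3_ops.end() && it->second.count(DataType::FLOAT32) > 0;
  std::cout << "flashmask_attention supports FLOAT32: "
            << (fp32_ok ? "yes" : "no") << "\n";  // prints "no"
  return 0;
}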

paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc

Lines changed: 174 additions & 33 deletions
@@ -17,36 +17,39 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 #ifdef PADDLE_WITH_XPU_XRE5
+#include "paddle/phi/kernels/slice_kernel.h"
 #include "paddle/phi/kernels/xpu/flash_attn_utils.h"
 #include "xfa/flash_api.h"
 #endif
 namespace phi {
 #ifdef PADDLE_WITH_XPU_XRE5
 template <typename T, typename Context>
-void FlashAttnGradKernelBase(const Context& ctx,
-                             const DenseTensor& q,
-                             const DenseTensor& k,
-                             const DenseTensor& v,
-                             const api::VectorParam<int>& lod_seqlen_q,
-                             const api::VectorParam<int>& lod_seqlen_k,
-                             const DenseTensor& out,
-                             const DenseTensor& softmax_lse,
-                             const DenseTensor& seed_offset,
-                             const paddle::optional<DenseTensor>& attn_mask,
-                             const DenseTensor& dout,
-                             const int batch_size,
-                             const Scalar& max_seqlen_q_,
-                             const Scalar& max_seqlen_k_,
-                             const int num_heads,
-                             const int num_heads_k,
-                             const int head_size,
-                             const int head_size_v,
-                             float scale,
-                             float dropout,
-                             bool causal,
-                             DenseTensor* dq,
-                             DenseTensor* dk,
-                             DenseTensor* dv) {
+void FlashAttnGradKernelBase(
+    const Context& ctx,
+    const DenseTensor& q,
+    const DenseTensor& k,
+    const DenseTensor& v,
+    const api::VectorParam<int>& lod_seqlen_q,
+    const api::VectorParam<int>& lod_seqlen_k,
+    const DenseTensor& out,
+    const DenseTensor& softmax_lse,
+    const DenseTensor& seed_offset,
+    const paddle::optional<DenseTensor>& attn_mask,
+    const paddle::optional<DenseTensor>& startend_row_indices,
+    const DenseTensor& dout,
+    const int batch_size,
+    const Scalar& max_seqlen_q_,
+    const Scalar& max_seqlen_k_,
+    const int num_heads,
+    const int num_heads_k,
+    const int head_size,
+    const int head_size_v,
+    float scale,
+    float dropout,
+    bool causal,
+    DenseTensor* dq,
+    DenseTensor* dk,
+    DenseTensor* dv) {
   xpu::ctx_guard RAII_GUARD(ctx.x_context());
 
   using XPUType = typename XPUTypeTrait<T>::Type;
@@ -62,7 +65,52 @@ void FlashAttnGradKernelBase(const Context& ctx,
 
   const float* bias_data = nullptr;
   int64_t fa_layout = AttnQKVLayout_t::ATTN_BLHD;
-  if (attn_mask.get_ptr() != nullptr) {
+  DenseTensor downstart_row_indices, upend_row_indices, downend_row_indices,
+      upstart_row_indices;
+  void *downstart_row_indices_data = nullptr, *upend_row_indices_data = nullptr,
+       *downend_row_indices_data = nullptr, *upstart_row_indices_data = nullptr;
+  bool is_flashmask = startend_row_indices.get_ptr() != nullptr;
+  XPUStream flashmask_stream;
+  if (is_flashmask) {
+    xpu_stream_create(&flashmask_stream);
+    PADDLE_ENFORCE_EQ(
+        startend_row_indices->dims().size(),
+        4,
+        common::errors::InvalidArgument(
+            "flashmask_attention receive startend_row_indices with dim "
+            "[batch_size, num_heads,seq_len, mask_bounds]"));
+    PADDLE_ENFORCE_EQ(startend_row_indices->dims()[3] == 1 ||
+                          startend_row_indices->dims()[3] == 2 ||
+                          startend_row_indices->dims()[3] == 4,
+                      true,
+                      common::errors::InvalidArgument(
+                          "flashmask_attention startend_row_indices "
+                          "mask_bounds must in [1,2,4]"));
+    downstart_row_indices =
+        phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {0}, {1});
+    downstart_row_indices_data = downstart_row_indices.data();
+    if (startend_row_indices->dims()[3] == 2) {
+      if (!causal) {
+        upend_row_indices =
+            phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
+        upend_row_indices_data = upend_row_indices.data();
+      } else {
+        downend_row_indices =
+            phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
+        downend_row_indices_data = downend_row_indices.data();
+      }
+    } else if (startend_row_indices->dims()[3] == 4) {
+      upend_row_indices =
+          phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {3}, {4});
+      upend_row_indices_data = upend_row_indices.data();
+      downend_row_indices =
+          phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {1}, {2});
+      downend_row_indices_data = downend_row_indices.data();
+      upstart_row_indices =
+          phi::Slice<int32_t>(ctx, startend_row_indices.get(), {3}, {2}, {3});
+      upstart_row_indices_data = upstart_row_indices.data();
+    }
+  } else if (attn_mask.get_ptr() != nullptr) {
     const auto& mask_dims = attn_mask->dims();
     if (mask_dims.size() == 3 || (mask_dims[1] == 1 && mask_dims.size() == 4)) {
       fa_layout |= AttnQKVLayout_t::BIAS_BLL;
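
To summarize the slicing logic in this hunk: startend_row_indices has shape [batch_size, num_heads, seq_len, mask_bounds], and the last dimension is split with phi::Slice along axis 3 into up to four per-row bounds. The sketch below is an illustration written for this note, not code from the commit; DescribeFlashMaskColumns is a hypothetical helper that prints which column feeds which bound under the same branch rules.

// Illustration only: which column of the last (mask_bounds) dimension is
// routed to which row-index buffer, mirroring the branches in the hunk above.
#include <cstdio>

void DescribeFlashMaskColumns(int mask_bounds, bool causal) {
  std::printf("col 0 -> downstart_row_indices\n");  // always sliced
  if (mask_bounds == 2) {
    // Two bounds: column 1 is the down-end bound when causal,
    // otherwise the up-end bound.
    std::printf("col 1 -> %s\n",
                causal ? "downend_row_indices" : "upend_row_indices");
  } else if (mask_bounds == 4) {
    // Four bounds: [downstart, downend, upstart, upend].
    std::printf("col 1 -> downend_row_indices\n");
    std::printf("col 2 -> upstart_row_indices\n");
    std::printf("col 3 -> upend_row_indices\n");
  }
}

int main() {
  DescribeFlashMaskColumns(2, /*causal=*/true);
  DescribeFlashMaskColumns(4, /*causal=*/false);
  return 0;
}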
@@ -175,15 +223,17 @@ void FlashAttnGradKernelBase(const Context& ctx,
       -1,           // window_size_left
       -1,           // window_size_right
       head_size_v,  // v_head_dim
-      nullptr,      // downstart_row_indices_data
-      nullptr,      // downend_row_indices_data
-      nullptr,      // upstart_row_indices_data
-      nullptr,      // upend_row_indices_data
-      0,            // flash_mask_head_num
-      nullptr,      // flashmask_maxmin
-      nullptr       // side_stream
-  );
+      (const int*)downstart_row_indices_data,
+      (const int*)downend_row_indices_data,
+      (const int*)upstart_row_indices_data,
+      (const int*)upend_row_indices_data,
+      is_flashmask ? startend_row_indices->dims()[1] : 0,
+      nullptr,
+      is_flashmask ? flashmask_stream : nullptr);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
+  if (is_flashmask && flashmask_stream != nullptr) {
+    xpu_stream_destroy(flashmask_stream);
+  }
 }
 #endif
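Note on the stream handling above: when is_flashmask is set, a side stream is created with xpu_stream_create before the mha_varlen_bwd call, passed as the last argument, and destroyed with xpu_stream_destroy afterwards. The RAII sketch below is one generic way to express that create/use/destroy pairing; it is not part of the commit and uses a hypothetical handle type and callables rather than the real XPUStream API.

// Generic RAII sketch (hypothetical, not the commit's code): tie a side
// stream's lifetime to a scope so the destroy call cannot be skipped.
#include <functional>
#include <utility>

template <typename Handle>
class ScopedStream {
 public:
  ScopedStream(std::function<void(Handle*)> create,
               std::function<void(Handle)> destroy)
      : destroy_(std::move(destroy)) {
    create(&handle_);
  }
  ~ScopedStream() { destroy_(handle_); }
  ScopedStream(const ScopedStream&) = delete;
  ScopedStream& operator=(const ScopedStream&) = delete;
  Handle get() const { return handle_; }

 private:
  Handle handle_{};
  std::function<void(Handle)> destroy_;
};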

@@ -236,6 +286,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                              softmax_lse,
                              seed_offset,
                              attn_mask,
+                             paddle::none,
                              dout,
                              batch_size,
                              max_seqlen_q,
@@ -316,6 +367,7 @@ void FlashAttnGradKernel(const Context& ctx,
                              softmax_lse,
                              seed_offset,
                              attn_mask,
+                             paddle::none,
                              dout,
                              batch_size,
                              seqlen_q,
@@ -336,6 +388,85 @@
 #endif
 }
 
+template <typename T, typename Context>
+void FlashMaskGradKernel(const Context& ctx,
+                         const DenseTensor& q,
+                         const DenseTensor& k,
+                         const DenseTensor& v,
+                         const DenseTensor& startend_row_indices,
+                         const DenseTensor& out,
+                         const DenseTensor& softmax_lse,
+                         const DenseTensor& seed_offset,
+                         const DenseTensor& dout,
+                         float dropout,
+                         bool causal,
+                         DenseTensor* dq,
+                         DenseTensor* dk,
+                         DenseTensor* dv) {
+#ifdef PADDLE_WITH_XPU_XRE5
+  ctx.template Alloc<T>(dq);
+  ctx.template Alloc<T>(dk);
+  ctx.template Alloc<T>(dv);
+
+  // q, k, v [batch_size, seq_len, num_heads, head_dim]
+  const auto& dims = q.dims();
+
+  const int64_t batch_size = dims[0];
+  const int64_t seqlen_q = dims[1];
+  const int64_t num_heads = dims[2];
+  const int64_t head_size_og = dout.dims()[3];
+  const int64_t head_size = dims[3];
+  const int64_t head_size_v = v.dims()[3];
+  const int64_t seqlen_k = k.dims()[1];
+  const int64_t num_heads_k = k.dims()[2];
+
+  PADDLE_ENFORCE_EQ(
+      head_size_og,
+      head_size_v,
+      common::errors::InvalidArgument(
+          "flash_attn_bwd receive input with head_size_og == head_size_v"));
+
+  // lod info
+  std::vector<int> qlod_vec = {0};
+  std::vector<int> kvlod_vec = {0};
+  for (int batch_idx = 1; batch_idx <= batch_size; ++batch_idx) {
+    qlod_vec.push_back(seqlen_q * batch_idx);
+    kvlod_vec.push_back(seqlen_k * batch_idx);
+  }
+  api::VectorParam<int> qlod{
+      qlod_vec.data(), static_cast<int64_t>(qlod_vec.size()), nullptr};
+  api::VectorParam<int> kvlod{
+      kvlod_vec.data(), static_cast<int64_t>(kvlod_vec.size()), nullptr};
+  FlashAttnGradKernelBase<T>(ctx,
+                             q,
+                             k,
+                             v,
+                             qlod,
+                             kvlod,
+                             out,
+                             softmax_lse,
+                             seed_offset,
+                             paddle::none,
+                             startend_row_indices,
+                             dout,
+                             batch_size,
+                             seqlen_q,
+                             seqlen_k,
+                             num_heads,
+                             num_heads_k,
+                             head_size,
+                             head_size_v,
+                             0.0,
+                             dropout,
+                             causal,
+                             dq,
+                             dk,
+                             dv);
+#else
+  PADDLE_THROW(common::errors::Unimplemented(
+      "re-compile using -DWITH_XPU_XRE5=ON to use FlashMaskGradKernel"));
+#endif
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(flash_attn_unpadded_grad,
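
A note on the lod vectors built in FlashMaskGradKernel above: because the flashmask path takes fixed-length (padded) q/k/v, the lod arrays are simply cumulative sequence offsets, e.g. batch_size = 2 and seqlen_q = 3 give {0, 3, 6}. The snippet below is a standalone illustration of that construction, not code from the commit; BuildFixedLengthLod is a hypothetical helper.

// Standalone illustration: cumulative sequence offsets for a fixed-length
// batch, matching the qlod_vec/kvlod_vec loop in FlashMaskGradKernel.
#include <cstdio>
#include <vector>

std::vector<int> BuildFixedLengthLod(int batch_size, int seqlen) {
  std::vector<int> lod = {0};
  for (int batch_idx = 1; batch_idx <= batch_size; ++batch_idx) {
    lod.push_back(seqlen * batch_idx);
  }
  return lod;
}

int main() {
  for (int offset : BuildFixedLengthLod(/*batch_size=*/2, /*seqlen=*/3)) {
    std::printf("%d ", offset);  // prints: 0 3 6
  }
  std::printf("\n");
  return 0;
}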
@@ -359,3 +490,13 @@ PD_REGISTER_KERNEL(flash_attn_grad,
                    phi::dtype::float16) {
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);  // seed_offset
 }
+
+PD_REGISTER_KERNEL(flashmask_attention_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::FlashMaskGradKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
+  kernel->InputAt(6).SetBackend(
+      phi::Backend::ALL_BACKEND);  // fixed_seed_offset
+}
