
Commit e088d36

mp support for fuse attention
cache structure support for fuse attention
1 parent 60b86b2 commit e088d36

4 files changed, +128 -27 lines changed

paddle/fluid/operators/fused/fmha_ref.h

Lines changed: 29 additions & 5 deletions
@@ -15,6 +15,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/transpose_op.cu.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

 namespace paddle {
@@ -74,35 +76,57 @@ class FMHARef {
                       Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
                       Tensor* dropout_mask_out_tensor,
                       Tensor* dropout_out_tensor, Tensor* qktv_out_tensor,
-                      Tensor* fmha_out_tensor) {
+                      Tensor* fmha_out_tensor, const Tensor* cache_k,
+                      const Tensor* cache_v, Tensor* cache_k_out,
+                      Tensor* cache_v_out) {
     // input shape: [bs, seq_len, 3, num_head, head_dim]
-    // transpose with perm [2, 0, 1, 3, 4],
+    // transpose with perm [2, 0, 3, 1, 4],
     // output_shape: [3, bs, num_head, seq_len, head_dim]
     int ndims = 5;
     std::vector<int> perm_1 = {2, 0, 3, 1, 4};
     TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qkv_input_tensor, perm_1,
                                 transpose_2_out_tensor);
-
     T* qkv_data = transpose_2_out_tensor->data<T>();
     T* qk_out_data = qk_out_tensor->data<T>();
     T* qktv_out_data = qktv_out_tensor->data<T>();
     T* softmax_out_data = softmax_out_tensor->data<T>();
     T* dropout_out_data = dropout_out_tensor->data<T>();
     T* fmha_out_data = fmha_out_tensor->data<T>();
+    const T* cache_k_data = cache_k ? cache_k->data<T>() : nullptr;
+    const T* cache_v_data = cache_k ? cache_v->data<T>() : nullptr;
+    int cache_size = 0;
+    int cache_seq_len = 0;
+    if (cache_k) {
+      cache_size = cache_k->numel();
+      cache_seq_len = cache_k->dims()[2];
+    }

     int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
     int k_size = q_size;
+    int new_k_size = cache_size + k_size;
     T* q_ptr = qkv_data;
     T* k_ptr = q_ptr + q_size;
     T* v_ptr = k_ptr + k_size;
+    if (cache_k) {
+      std::vector<Tensor> qkv = transpose_2_out_tensor->Split(1, 0);
+      int64_t kdims[4] = {qkv[1].dims()[1], qkv[1].dims()[2], qkv[1].dims()[3],
+                          qkv[1].dims()[4]};
+      qkv[1].Resize(phi::DDim(kdims, 4));
+      qkv[2].Resize(phi::DDim(kdims, 4));
+      phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
+      concat(dev_ctx_, {*cache_k, qkv[1]}, 2, cache_k_out);
+      concat(dev_ctx_, {*cache_v, qkv[2]}, 2, cache_v_out);
+      k_ptr = cache_k_out->data<T>();
+      v_ptr = cache_v_out->data<T>();
+    }

     // q*k^t, batched_gemm
     CBLAS_TRANSPOSE transA = CblasNoTrans;
     CBLAS_TRANSPOSE transB = CblasTrans;
     auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     int gemm_batch_size = batch_size_ * num_head_;
     int gemm_m = seq_len_;
-    int gemm_n = seq_len_;
+    int gemm_n = cache_seq_len + seq_len_;
     int gemm_k = head_dim_;
     T alpha = static_cast<T>(1.0 / sqrt(head_dim_));
     T beta = static_cast<T>(0.0);
@@ -133,7 +157,7 @@ class FMHARef {
     transB = CblasNoTrans;
     gemm_m = seq_len_;
     gemm_n = head_dim_;
-    gemm_k = seq_len_;
+    gemm_k = cache_seq_len + seq_len_;
     alpha = static_cast<T>(1.0);
     stride_a = gemm_m * gemm_k;
     stride_b = gemm_k * gemm_n;
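
The cache branch above splits the transposed QKV, concatenates the cached K/V with the current step's K/V along the sequence axis (dim 2), and widens the QK^T GEMM from seq_len to cache_seq_len + seq_len. Below is a minimal NumPy sketch of that computation; the shapes and names are illustrative only, not the kernel itself.

```python
# Sketch of the cache path: concatenate cached K/V with the current K/V along
# the sequence axis (dim 2), then run scaled dot-product attention.
import numpy as np

bs, num_head, head_dim = 2, 4, 8
cache_seq_len, seq_len = 16, 1          # decoding one new token per step

q = np.random.randn(bs, num_head, seq_len, head_dim)
k = np.random.randn(bs, num_head, seq_len, head_dim)
v = np.random.randn(bs, num_head, seq_len, head_dim)
cache_k = np.random.randn(bs, num_head, cache_seq_len, head_dim)
cache_v = np.random.randn(bs, num_head, cache_seq_len, head_dim)

# Mirrors concat(dev_ctx_, {*cache_k, qkv[1]}, 2, cache_k_out) above.
k_all = np.concatenate([cache_k, k], axis=2)   # [bs, num_head, cache+seq, head_dim]
v_all = np.concatenate([cache_v, v], axis=2)

# q*k^t: the last dimension becomes cache_seq_len + seq_len, as in the
# modified BatchedGEMM (gemm_n = cache_seq_len + seq_len_).
qk = q @ k_all.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
attn = np.exp(qk - qk.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)
out = attn @ v_all                              # [bs, num_head, seq_len, head_dim]
assert out.shape == (bs, num_head, seq_len, head_dim)
```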

paddle/fluid/operators/fused/fused_attention_op.cc

Lines changed: 32 additions & 11 deletions
@@ -105,12 +105,14 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                           "input qkv_weight = [%s]",
                           x_dim, y_dim));

-    PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
-                      platform::errors::InvalidArgument(
-                          "The dimensions of qkv_weight must be 4"
-                          "(3, num_head, dim_head, dim_embed),"
-                          "and must satisfy the limitations: "
-                          "(num_head * dim_head == dim_embed)"));
+    if (ctx->Attrs().Get<int>("ring_id") == -1) {
+      PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
+                        platform::errors::InvalidArgument(
+                            "The dimensions of qkv_weight must be 4"
+                            "(3, num_head, dim_head, dim_embed),"
+                            "and must satisfy the limitations: "
+                            "(num_head * dim_head == dim_embed)"));
+    }

     if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
       ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
@@ -133,19 +135,28 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("TransposeOut2",
                       {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
     // [batch, num_head, seq_len, seq_len]
-    ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    auto last_dim = x_dim[1];
+    if (ctx->HasInput("CacheK")) {
+      auto cache_dim = ctx->GetInputDim("CacheK");
+      last_dim += cache_dim[2];
+      ctx->SetOutputDim("CacheKOut",
+                        {cache_dim[0], cache_dim[1], last_dim, cache_dim[3]});
+      ctx->SetOutputDim("CacheVOut",
+                        {cache_dim[0], cache_dim[1], last_dim, cache_dim[3]});
+    }
+    ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], last_dim});

     if (ctx->HasInput("SrcMask")) {
-      ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+      ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], last_dim});
     }
     // the same as QKOut's shape.
     ctx->SetOutputDim("AttnDropoutOut",
-                      {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+                      {x_dim[0], y_dim[1], x_dim[1], last_dim});
     if (ctx->Attrs().Get<bool>("attn_dropout_is_test") == false) {
       ctx->SetOutputDim("AttnDropoutMaskOut",
-                        {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+                        {x_dim[0], y_dim[1], x_dim[1], last_dim});
     }
-    ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], last_dim});
     // [batch_size, num_heads, seq_len, head_dim]
     ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
     // [batch_size, seq_len, number of heads*head size]
@@ -194,6 +205,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
              "(optional) Bias is a 1-dimensional tensor of size "
              "H. Here, H represents the last dimension of its input tensor.")
         .AsDispensable();
+    AddInput("CacheK", "(optional) The cached K for generation inference.")
+        .AsDispensable();
+    AddInput("CacheV", "(optional) The cached V for generation inference.")
+        .AsDispensable();
     AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
     AddOutput("LnVariance", "Variance of the current mini batch.")
         .AsIntermediate();
@@ -217,6 +232,8 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("BiasDropoutResidualOut",
               "Result of residual + dropout(src + bias).")
         .AsIntermediate();
+    AddOutput("CacheKOut", "The updated cache K.");
+    AddOutput("CacheVOut", "The updated cache V.");
     AddOutput("Y", "Result after attention.");

     AddAttr<bool>("pre_layer_norm",
@@ -324,6 +341,10 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
                       "0.0 and 0.001, But received [%s].",
                       ln_epsilon));
         });
+    AddAttr<int>(
+        "ring_id",
+        "ring id for tensor model parallel, used in distributed training and inference")
+        .SetDefault(-1);

     AddComment(R"DOC(
   Add fused attention op whose logic is as follows:
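
With CacheK present, InferShape grows the last attention dimension from seq_len to cache_seq_len + seq_len and gives the cache outputs the concatenated sequence length. The Python sketch below mirrors that shape logic; the helper name and dict keys are illustrative, not part of the op.

```python
# Hedged sketch of the InferShape change for the cache path.
# x_dim: [batch, seq_len, dim_embed]; y_dim: [3, num_head, head_dim, dim_embed];
# cache_k_dim: [batch, num_head, cache_seq_len, head_dim].
def infer_fused_attention_shapes(x_dim, y_dim, cache_k_dim=None):
    batch, seq_len = x_dim[0], x_dim[1]
    num_head = y_dim[1]
    last_dim = seq_len
    shapes = {}
    if cache_k_dim is not None:
        # The cache outputs keep the cache layout, with the grown seq length.
        last_dim += cache_k_dim[2]
        shapes["CacheKOut"] = [cache_k_dim[0], cache_k_dim[1], last_dim, cache_k_dim[3]]
        shapes["CacheVOut"] = list(shapes["CacheKOut"])
    # QKOut, SrcMaskOut, SoftmaxOut, AttnDropoutOut all share this last dim.
    shapes["QKOut"] = [batch, num_head, seq_len, last_dim]
    shapes["SoftmaxOut"] = [batch, num_head, seq_len, last_dim]
    return shapes

print(infer_fused_attention_shapes([2, 1, 128], [3, 4, 32, 128],
                                   cache_k_dim=[2, 4, 16, 32]))
# {'CacheKOut': [2, 4, 17, 32], 'CacheVOut': [2, 4, 17, 32],
#  'QKOut': [2, 4, 1, 17], 'SoftmaxOut': [2, 4, 1, 17]}
```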

paddle/fluid/operators/fused/fused_attention_op.cu

Lines changed: 48 additions & 9 deletions
@@ -27,6 +27,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fmha_ref.h"
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"

+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {

@@ -51,6 +56,10 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     // y: qkv's weight: [3, num_head, dim_head, dim_embed]
     auto *qkv_weight = ctx.Input<Tensor>("QKVW");
     auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
+    auto *cache_k = ctx.Input<Tensor>("CacheK");
+    auto *cache_v = ctx.Input<Tensor>("CacheV");
+    auto *cache_k_out = ctx.Output<Tensor>("CacheKOut");
+    auto *cache_v_out = ctx.Output<Tensor>("CacheVOut");
     auto *qkv_out = ctx.Output<Tensor>("QKVOut");
     auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

@@ -86,6 +95,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
     bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
     int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
+    int ring_id = ctx.Attr<int>("ring_id");

     // final output.
     auto *out = ctx.Output<Tensor>("Y");
@@ -128,6 +138,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
     auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

+    auto *cache_k_out_data =
+        cache_k_out ? cache_k_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
+    auto *cache_v_out_data =
+        cache_v_out ? cache_v_out->mutable_data<T>(ctx.GetPlace()) : nullptr;
+
     int batch_size = input_x_dims[0];
     int max_seq_len = input_x_dims[1];
     int dim_embed = input_x_dims[2];
@@ -161,9 +176,14 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {

     output_size = hidden_size;
     // (transA, transB, compute_bias) = (false, false, false)
+    // NOTE(Yuang Liu): In the general case, input size == output size, so
+    // swapping the two has no effect. Under mp, the output size is
+    // mp_head * dkey, which is actually the input size, while the input size
+    // is the hidden size, which is actually the output size. So for the out
+    // linear, switch the input size and the output size.
     auto out_linear_compute =
         AttnMatMul<T>(ctx.cuda_device_context(), false, false, bsz_seq,
-                      output_size, input_size, false);
+                      input_size, output_size, false);
     DropoutParam dropout_param2(ctx, 0);
     FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
         ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
@@ -186,22 +206,41 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
                                           qkv_bias_out);
     }
     if (qkv_bias == nullptr) {
-      fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
-                                      qk_out, src_mask_out, softmax_out,
-                                      attn_dropout_mask_out, attn_dropout_out,
-                                      qktv_out, fmha_out);
+      fmha_ref_compute.ComputeForward(
+          *qkv_out, src_mask, transpose_out_2, qk_out, src_mask_out,
+          softmax_out, attn_dropout_mask_out, attn_dropout_out, qktv_out,
+          fmha_out, cache_k, cache_v, cache_k_out, cache_v_out);
     } else {
-      fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
-                                      qk_out, src_mask_out, softmax_out,
-                                      attn_dropout_mask_out, attn_dropout_out,
-                                      qktv_out, fmha_out);
+      fmha_ref_compute.ComputeForward(
+          *qkv_bias_out, src_mask, transpose_out_2, qk_out, src_mask_out,
+          softmax_out, attn_dropout_mask_out, attn_dropout_out, qktv_out,
+          fmha_out, cache_k, cache_v, cache_k_out, cache_v_out);
     }

     // fmha_out: [batch_size, seq_len, num_head, head_dim]
     // weight: [embed_dim, embed_dim]
     // out_linear_out: [batch_size, seq_len, embed_dim]
     out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
                                       out_linear_out, nullptr);
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    if (ring_id >= 0) {
+      ncclDataType_t dtype = platform::ToNCCLDataType(
+          framework::TransToProtoVarType(out_linear_out->dtype()));
+      auto place = ctx.GetPlace();
+      int64_t numel = out_linear_out->numel();
+      const void *sendbuff = out_linear_out->data<T>();
+      void *recvbuff = out_linear_out->mutable_data<T>(place);
+      auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      gpuStream_t stream =
+          static_cast<platform::CUDADeviceContext *>(dev_ctx)->stream();
+      ncclRedOp_t nccl_red_type = ncclSum;
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+          sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(),
+          stream));
+    }
+#endif
+
     if (pre_layer_norm) {
       // output = (residual + dropout(input + bias))
       fused_dropout_layernorm_helper.ResidualDropoutBias(
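
Under tensor model parallel, each rank computes attention over its slice of the heads and multiplies by the matching row slice of the out-linear weight, so the per-rank partial projections sum to the full one; that sum is what the ncclAllReduce over out_linear_out restores. A small NumPy sketch of the identity follows (the two-rank split and all names are illustrative).

```python
# Sketch: column-split activations times row-split weights, summed across
# ranks, equals the unsplit projection (the all-reduce in the kernel).
import numpy as np

bsz_seq, num_head, head_dim = 3, 4, 8
dim_embed = num_head * head_dim
fmha_out = np.random.randn(bsz_seq, num_head * head_dim)  # attention output
w_out = np.random.randn(num_head * head_dim, dim_embed)   # out-linear weight

full = fmha_out @ w_out                      # single-GPU reference

partials = []
for rank in range(2):                        # mp degree 2
    cols = slice(rank * dim_embed // 2, (rank + 1) * dim_embed // 2)
    # Each rank only holds half the heads and half the weight rows.
    partials.append(fmha_out[:, cols] @ w_out[cols, :])

# ncclAllReduce(sum) over out_linear_out reproduces the full projection.
np.testing.assert_allclose(sum(partials), full, atol=1e-8)
```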

python/paddle/incubate/nn/functional/fused_transformer.py

Lines changed: 19 additions & 2 deletions
@@ -229,7 +229,10 @@ def fused_multi_head_attention(x,
                                ln_epsilon=1e-05,
                                training=True,
                                mode='upscale_in_train',
-                               name=None):
+                               name=None,
+                               ring_id=-1,
+                               cache_k=None,
+                               cache_v=None):
     r"""
     Attention maps queries and a set of key-value pairs to outputs, and
     Multi-Head Attention performs multiple parallel attention to jointly attending
@@ -304,6 +307,9 @@ def fused_multi_head_attention(x,
                                - train: out = input * mask
                                - inference: out = input * (1.0 - p)
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+        ring_id (int, optional): The ring id for tensor model parallel; only NCCL and the forward pass are supported. Default is -1, which means mp is not used.
+        cache_k (Tensor, optional): The cache structure of K for the generation model.
+        cache_v (Tensor, optional): The cache structure of V for the generation model.

     Returns:
         Tensor: The output Tensor, the data type and shape is same as `x`.
@@ -398,6 +404,9 @@ def fused_multi_head_attention(x,
             inputs['Ln2Scale'] = [ln_scale]
         if ln_bias:
             inputs['Ln2Bias'] = [ln_bias]
+        if cache_k:
+            inputs['CacheK'] = [cache_k]
+            inputs['CacheV'] = [cache_v]

         if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
             seed = helper.main_program.random_seed
@@ -417,6 +426,7 @@ def fused_multi_head_attention(x,
             'dropout_seed': seed if seed is not None else 0,
             'attn_dropout_implementation': mode,
             'dropout_implementation': mode,
+            'ring_id': ring_id
         }

         # set outputs
@@ -449,6 +459,8 @@ def fused_multi_head_attention(x,
         bias_dropout_residual_out = helper.create_variable_for_type_inference(
             dtype=dtype)
         final_out = helper.create_variable_for_type_inference(dtype=dtype)
+        cache_k_out = helper.create_variable_for_type_inference(dtype=dtype)
+        cache_v_out = helper.create_variable_for_type_inference(dtype=dtype)

         helper.append_op(
             type='fused_attention',
@@ -472,7 +484,12 @@ def fused_multi_head_attention(x,
                 "Ln2Mean": ln_mean_out,
                 "Ln2Variance": ln_variance_out,
                 "BiasDropoutResidualOut": bias_dropout_residual_out,
-                'Y': final_out
+                'Y': final_out,
+                'CacheKOut': cache_k_out,
+                'CacheVOut': cache_v_out
             },
             attrs=attrs)
+
+        if cache_k:
+            return [final_out, cache_k_out, cache_v_out]
         return final_out
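
Below is a hedged sketch of how a caller might drive the new cache arguments for one decoding step. Only cache_k, cache_v, ring_id, and the three-element return are taken from this change; the remaining arguments follow the function's existing signature, the shapes are illustrative, and whether the cache path is reachable depends on the execution mode this function routes to.

```python
# Assumed single decoding step with a running K/V cache (illustrative shapes).
import paddle
from paddle.incubate.nn.functional import fused_multi_head_attention

batch, num_head, head_dim = 2, 4, 32
embed_dim = num_head * head_dim

x = paddle.randn([batch, 1, embed_dim])                    # one new token
qkv_weight = paddle.randn([3, num_head, head_dim, embed_dim])
linear_weight = paddle.randn([embed_dim, embed_dim])
cache_k = paddle.randn([batch, num_head, 16, head_dim])    # 16 cached steps
cache_v = paddle.randn([batch, num_head, 16, head_dim])

# With cache_k passed, the op also returns the grown caches for the next step.
out, cache_k, cache_v = fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    cache_k=cache_k, cache_v=cache_v, ring_id=-1)
# cache_k/cache_v now cover 17 positions and feed the next decoding step.
```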
