164 changes: 150 additions & 14 deletions paddle/phi/kernels/fusion/gpu/fused_attention_grad_kernel.cu
@@ -21,6 +21,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
@@ -106,6 +107,130 @@ void FusedAttentionGradKernel(
DenseTensor *fmha_out_grad,
DenseTensor *out_linear_out_grad) {
using U = phi::fusion::LayerNormParamType<T>;
if (x.numel() == 0) {
if (qkv_bias_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(qkv_bias_grad->dims())),
0,
qkv_bias_grad);
if (qkv_bias_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(qkv_bias_out_grad->dims())),
0,
qkv_bias_out_grad);
if (src_mask_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(src_mask_out_grad->dims())),
0,
src_mask_out_grad);
if (out_linear_bias_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(out_linear_bias_grad->dims())),
0,
out_linear_bias_grad);
if (ln_scale_grad)
phi::Full<U, Context>(
dev_ctx,
phi::IntArray(common::vectorize(ln_scale_grad->dims())),
0,
ln_scale_grad);
if (ln_bias_grad)
phi::Full<U, Context>(
dev_ctx,
phi::IntArray(common::vectorize(ln_bias_grad->dims())),
0,
ln_bias_grad);
if (ln_scale_2_grad)
phi::Full<U, Context>(
dev_ctx,
phi::IntArray(common::vectorize(ln_scale_2_grad->dims())),
0,
ln_scale_2_grad);
if (ln_bias_2_grad)
phi::Full<U, Context>(
dev_ctx,
phi::IntArray(common::vectorize(ln_bias_2_grad->dims())),
0,
ln_bias_2_grad);
if (x_grad) dev_ctx.template Alloc<T>(x_grad);
if (qkv_weight_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(qkv_weight_grad->dims())),
0,
qkv_weight_grad);
if (out_linear_weight_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(out_linear_weight_grad->dims())),
0,
out_linear_weight_grad);
if (ln_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(ln_out_grad->dims())),
0,
ln_out_grad);
if (bias_dropout_residual_out_grad)
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(
bias_dropout_residual_out_grad->dims())),
0,
bias_dropout_residual_out_grad);
if (qkv_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(qkv_out_grad->dims())),
0,
qkv_out_grad);
if (qktv_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(qktv_out_grad->dims())),
0,
qktv_out_grad);
if (transpose_out_2_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(transpose_out_2_grad->dims())),
0,
transpose_out_2_grad);
if (qk_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(qk_out_grad->dims())),
0,
qk_out_grad);
if (softmax_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(softmax_out_grad->dims())),
0,
softmax_out_grad);
if (attn_dropout_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(attn_dropout_out_grad->dims())),
0,
attn_dropout_out_grad);
if (fmha_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(fmha_out_grad->dims())),
0,
fmha_out_grad);
if (out_linear_out_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(out_linear_out_grad->dims())),
0,
out_linear_out_grad);
return;
}

const bool has_attn_dropout = (attn_dropout_rate != 0.0f);

@@ -323,20 +448,31 @@ void FusedAttentionGradKernel(
bias_dropout_residual_out_grad,
bias_dropout_residual_out_grad->numel() * sizeof(T));

fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
dev_ctx,
d_y_data,
bias_dropout_residual_out_data,
dropout_mask_out_data,
ln_2_scale_data,
ln_mean_2_data,
ln_var_2_data,
d_bias_dropout_residual_out_data,
d_ln_2_scale_data,
d_ln_bias_2_data,
d_out_linear_out_data,
d_out_linear_bias_data,
d_residual_data);
const bool is_ln_scale_2_zero_size =
    ln_scale_2_p && ln_scale_2_p->numel() == 0;
if (is_ln_scale_2_zero_size) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
dev_ctx,
d_y_data,
dropout_mask_out_data,
d_out_linear_out_data,
d_residual_data,
d_out_linear_bias_data);
} else {
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
dev_ctx,
d_y_data,
bias_dropout_residual_out_data,
dropout_mask_out_data,
ln_2_scale_data,
ln_mean_2_data,
ln_var_2_data,
d_bias_dropout_residual_out_data,
d_ln_2_scale_data,
d_ln_bias_2_data,
d_out_linear_out_data,
d_out_linear_bias_data,
d_residual_data);
}
}

out_linear_compute.ComputeBackward(fmha_out_p,
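For reference, a minimal NumPy sketch of what the early return above amounts to, using shapes borrowed from the zero-size test added further down (batch_size = 0, seq_len = 128, embed_dim = 1024): parameter gradients keep their registered shapes and are genuinely zero-filled, while gradients that share x's 0-size batch dimension end up empty, so zero-filling them is effectively a no-op.

```python
import numpy as np

# Shapes borrowed from TestFusedAttentionOp_ZeroSize below:
# batch_size = 0, seq_len = 128, num_heads = 16, head_dim = 64, embed_dim = 1024.
x = np.empty((0, 128, 1024), dtype=np.float32)
qkv_weight = np.random.rand(3, 16, 64, 1024).astype(np.float32)

# Parameter gradients are true zero fills (mirroring the phi::Full calls above).
qkv_weight_grad = np.zeros_like(qkv_weight)

# Gradients with the 0-size batch dimension hold no elements, so zero-filling
# them is a no-op; plain allocation suffices (as done for x_grad above).
x_grad = np.zeros_like(x)

assert qkv_weight_grad.shape == qkv_weight.shape and not qkv_weight_grad.any()
assert x_grad.size == 0
```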
39 changes: 39 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_attention_kernel.cu
@@ -21,6 +21,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
@@ -86,6 +87,31 @@ void FusedAttentionKernel(const Context &dev_ctx,
DenseTensor *cache_kv_out,
DenseTensor *out) {
using U = phi::funcs::LayerNormParamType<T>;
if (x.numel() == 0) {
if (ln_mean) dev_ctx.template Alloc<U>(ln_mean);
if (ln_var) dev_ctx.template Alloc<U>(ln_var);
if (ln_out) dev_ctx.template Alloc<T>(ln_out);
if (qkv_out) dev_ctx.template Alloc<T>(qkv_out);
if (qkv_bias_out) dev_ctx.template Alloc<T>(qkv_bias_out);
if (transpose_out_2) dev_ctx.template Alloc<T>(transpose_out_2);
if (qk_out) dev_ctx.template Alloc<T>(qk_out);
if (qktv_out) dev_ctx.template Alloc<T>(qktv_out);
if (softmax_out) dev_ctx.template Alloc<T>(softmax_out);
if (attn_dropout_mask_out)
dev_ctx.template Alloc<uint8_t>(attn_dropout_mask_out);
if (attn_dropout_out) dev_ctx.template Alloc<T>(attn_dropout_out);
if (src_mask_out) dev_ctx.template Alloc<T>(src_mask_out);
if (fmha_out) dev_ctx.template Alloc<T>(fmha_out);
if (out_linear_out) dev_ctx.template Alloc<T>(out_linear_out);
if (dropout_mask_out) dev_ctx.template Alloc<uint8_t>(dropout_mask_out);
if (ln_mean_2) dev_ctx.template Alloc<U>(ln_mean_2);
if (ln_var_2) dev_ctx.template Alloc<U>(ln_var_2);
if (bias_dropout_residual_out)
dev_ctx.template Alloc<T>(bias_dropout_residual_out);
if (cache_kv_out) dev_ctx.template Alloc<T>(cache_kv_out);
dev_ctx.template Alloc<T>(out);
return;
}

// x: qkv's input [batch_size, seq_len, dim_embed]
// if transpose_qkv_wb is False
@@ -347,6 +373,19 @@ void FusedAttentionKernel(const Context &dev_ctx,
dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
U *ln_var_2_ptr =
dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));

// 0-size ln_scale_2: skip the second layernorm
if (ln_scale_2_p && ln_scale_2_p->numel() == 0) {
// output = (residual + dropout(input + bias))
fused_dropout_layernorm_helper.ResidualDropoutBias(dev_ctx,
out_linear_out_data,
residual_ptr,
out_linear_bias_data,
final_out_data,
dropout_mask_out_data);
return;
}

// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
dev_ctx,
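A rough NumPy sketch of the two output paths selected above, assuming the dropout rate is 0 so dropout reduces to identity: a non-empty ln_scale_2 produces layernorm(residual + input + bias), while a 0-size ln_scale_2 selects the plain residual path.

```python
import numpy as np

def residual_dropout_bias(x, residual, bias):
    # ResidualDropoutBias path; dropout rate assumed 0, so dropout is identity.
    return residual + (x + bias)

def layernorm_residual_dropout_bias(x, residual, bias, scale, shift, eps=1e-5):
    # LayernormResidualDropoutBias path: layer norm over the last axis.
    h = residual + (x + bias)
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mean) / np.sqrt(var + eps) * scale + shift

x = np.random.rand(2, 4, 8).astype(np.float32)        # out_linear output
residual = np.random.rand(2, 4, 8).astype(np.float32)
bias = np.random.rand(8).astype(np.float32)
scale = np.ones(8, dtype=np.float32)                  # a 0-size scale would select the plain path
shift = np.zeros(8, dtype=np.float32)

plain = residual_dropout_bias(x, residual, bias)
normed = layernorm_residual_dropout_bias(x, residual, bias, scale, shift)
assert plain.shape == normed.shape == (2, 4, 8)
```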
@@ -36,6 +36,7 @@ void FusedGemmEpilogueGradKernel(
DenseTensor* y_grad,
DenseTensor* bias_grad) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
dev_ctx.template Alloc<T>(y_grad);
phi::FullKernel<T>(
dev_ctx, common::vectorize(y.dims()), 0.0, y.dtype(), y_grad);
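The explicit zero fill of y_grad above follows from the matmul gradient with a zero-row x: y_grad = x^T @ out_grad contracts over an empty dimension and is therefore an all-zeros matrix of y's shape, which a skipped 0-size GEMM would otherwise leave uninitialized. A quick NumPy check with the shapes used by the zero-size GEMM epilogue test further down:

```python
import numpy as np

x = np.empty((0, 4), dtype=np.float64)             # 0-size input, as in the new test
y = np.random.random((4, 128))                     # weight
out_grad = np.empty((0, 128), dtype=np.float64)    # upstream gradient, also 0-size

x_grad = out_grad @ y.T   # shape (0, 4): nothing to compute, allocation suffices
y_grad = x.T @ out_grad   # shape (4, 128): contraction over zero rows => all zeros

assert x_grad.shape == x.shape
assert np.array_equal(y_grad, np.zeros_like(y))
```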
33 changes: 32 additions & 1 deletion test/legacy_test/test_fused_attention_op.py
@@ -220,7 +220,10 @@ def GetBaselineOut(self):
)
out = self.out_proj(out_linear_in)

residual_out = residual + self.dropout(out)
if out.size == 0:
residual_out = residual
else:
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
final_out = self.norm1(residual_out)
else:
@@ -731,5 +734,33 @@ def test_fused_attention_op(self):
)


class TestFusedAttentionOp_ZeroSize(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = False
self.has_cache_kv = False
self.training = True

self.batch_size = 0 # 0-size
self.query_length = 128
self.cache_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads

self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = (
self.query_length,
self.query_length,
)
self.transpose_qkv_wb = False


if __name__ == "__main__":
unittest.main()
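As a quick sanity check of the baseline tweak in GetBaselineOut above (dropout skipped when `out` has no elements), the zero-size shapes from this test behave as follows in plain NumPy; this only illustrates the shape bookkeeping, not the fused kernel itself.

```python
import numpy as np

# Shapes from TestFusedAttentionOp_ZeroSize: batch_size = 0, seq_len = 128,
# embed_dim = 16 * 64 = 1024.
residual = np.empty((0, 128, 1024), dtype=np.float32)
out = np.empty((0, 128, 1024), dtype=np.float32)

# With no elements there is nothing to drop, so the baseline forwards the
# residual unchanged; the result is still a (0, 128, 1024) tensor.
residual_out = residual if out.size == 0 else residual + out
assert residual_out.shape == (0, 128, 1024)
```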
39 changes: 39 additions & 0 deletions test/legacy_test/test_fused_gemm_epilogue_op.py
@@ -654,6 +654,45 @@ def test_case_act(self):
paddle.enable_static()


@unittest.skipIf(
not is_fused_gemm_epilogue_supported(),
"fused_gemm_epilogue is only supported when CUDA version >= 11.6",
)
class TestEagerFusedGemmEpilogue_ZeroSize(unittest.TestCase):
def setUp(self):
paddle.set_device('gpu')

def test_case_act(self):
paddle.disable_static()
x_np = np.random.random((0, 4)).astype(np.float64) - 0.5
y_np = np.random.random((4, 128)).astype(np.float64) - 0.5
bias_np = np.random.random((128,)).astype(np.float64) - 0.5
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
y.stop_gradient = False

out1 = fused_linear_activation(x, y, bias, False, False, 'none')
out_np1 = get_output(x_np, y_np, bias_np, 'none')
np.testing.assert_allclose(out1, out_np1, rtol=1e-05)
out_grad_np1 = np.random.randint(
low=-20, high=20, size=out_np1.shape
).astype(np.float64)
paddle.autograd.backward(
out1, grad_tensors=[paddle.to_tensor(out_grad_np1)]
)

x_grad_np, y_grad_np, bias_grad_np = matmul_grad(
x_np, y_np, bias_np, out_grad_np1, False, False
)
np.testing.assert_allclose(x.grad.numpy(), x_grad_np, rtol=1e-05)
self.assertEqual(y_grad_np.shape, y_np.shape)
np.testing.assert_allclose(y.grad.numpy(), y_grad_np, rtol=1e-05)

paddle.enable_static()


if __name__ == "__main__":
paddle.enable_static()
np.random.seed(0)