14 changes: 14 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -4508,6 +4508,20 @@ void VariableLengthMemoryEfficientAttentionInferMeta(
true,
common::errors::InvalidArgument(
"The seq length of Key, Value should be equal."));
if (mask) {
PADDLE_ENFORCE_EQ(
mask.dims().size(),
4,
common::errors::InvalidArgument("Mask should be a 4-D tensor"
"But received Value dimension(%s)",
mask.dims().size()));
const int64_t mask_batch_size = mask.dims()[0];
PADDLE_ENFORCE_EQ(
query_batch_size == mask_batch_size,
true,
common::errors::InvalidArgument(
"The batch size of Query, Key, Value and Mask should be equal."));
}

std::vector<int64_t> out_dims(
{query_batch_size, query_num_head, query_seq_length, value_head_size});
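For reference, a minimal Python sketch of the contract the new checks enforce (shapes illustrative, not tied to any real API call): the optional mask must be 4-D and share its batch dimension with the query.

# Pure-Python mirror of the new InferMeta checks; names and shapes are hypothetical.
query_shape = [2, 8, 128, 64]   # [batch, num_head, seq_len, head_dim]
good_mask   = [2, 1, 128, 128]  # 4-D, batch matches the query -> accepted
bad_rank    = [2, 128, 128]     # 3-D -> "Mask should be a 4-D tensor"
bad_batch   = [4, 1, 128, 128]  # batch mismatch -> batch-size error

def check_mask(query_shape, mask_shape):
    assert len(mask_shape) == 4, "Mask should be a 4-D tensor"
    assert mask_shape[0] == query_shape[0], (
        "The batch size of Query, Key, Value and Mask should be equal.")

check_mask(query_shape, good_mask)  # passes; the other two would raise AssertionError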
@@ -67,7 +67,8 @@ void MultiHeadAttentionVariableForwardKernel(
params.causal = causal;
params.pre_cache_length = pre_cache_length;

if (mask) {
// if the mask is a 0-size tensor, we don't need to set mask_ptr
if (mask && mask.get().numel() > 0) {
// [B, 1, S, D]
auto mask_tensor = mask.get();
int64_t mask_num_heads = mask_tensor.dims()[1];
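The guard above treats a zero-size mask the same as no mask at all. A small Python mirror of that logic (function name hypothetical):

import numpy as np

def resolve_mask(mask):
    """None or a zero-size array both mean 'run without a mask'."""
    if mask is not None and mask.size > 0:
        return mask
    return None

assert resolve_mask(None) is None
assert resolve_mask(np.empty((2, 1, 0, 128))) is None  # 0-size mask is skipped
assert resolve_mask(np.ones((2, 1, 4, 4))) is not None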
11 changes: 5 additions & 6 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -38,8 +38,10 @@ void FusedRopeGradKernel(const Context& dev_ctx,
DenseTensor* dk,
DenseTensor* dv) {
int64_t numel = dout_q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(dq);
if (dout_k) dev_ctx.template Alloc<T>(dk);
if (dout_v) dev_ctx.template Alloc<T>(dv);
if (numel <= 0) return;

phi::Array<int64_t, 3> inputs_num_heads;
// small size for broadcast
@@ -70,22 +72,19 @@ void FusedRopeGradKernel(const Context& dev_ctx,
outs_data[0] = dq->data<T>();
int num_inputs = 1;

if (dout_k) {
dev_ctx.template Alloc<T>(dk);
if (dk && dk->numel() > 0) {
outs_data[num_inputs] = dk->data<T>();
ins_data[num_inputs] = dout_k->data<T>();
inputs_num_heads[num_inputs] = dk->dims()[2];
num_inputs++;
}

if (dout_v) {
dev_ctx.template Alloc<T>(dv);
if (dv && dv->numel() > 0) {
outs_data[num_inputs] = dv->data<T>();
ins_data[num_inputs] = dout_v->data<T>();
inputs_num_heads[num_inputs] = dv->dims()[2];
num_inputs++;
}

using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType div_c = static_cast<MPType>(1.0f / head_dim);

11 changes: 5 additions & 6 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -38,8 +38,10 @@ void FusedRopeKernel(const Context& dev_ctx,
DenseTensor* out_k,
DenseTensor* out_v) {
int64_t numel = q.numel();
if (numel <= 0) return;
dev_ctx.template Alloc<T>(out_q);
if (k) dev_ctx.template Alloc<T>(out_k);
if (v) dev_ctx.template Alloc<T>(out_v);
if (numel <= 0) return;

phi::Array<int64_t, 3> inputs_num_heads;

@@ -73,16 +75,13 @@ void FusedRopeKernel(const Context& dev_ctx,
outs_data[0] = out_q->data<T>();
int num_inputs = 1;

if (k) {
dev_ctx.template Alloc<T>(out_k);
if (out_k && out_k->numel() > 0) {
ins_data[num_inputs] = k->data<T>();
outs_data[num_inputs] = out_k->data<T>();
inputs_num_heads[num_inputs] = k->dims()[2];
num_inputs++;
}

if (v) {
dev_ctx.template Alloc<T>(out_v);
if (out_v && out_v->numel() > 0) {
ins_data[num_inputs] = v->data<T>();
outs_data[num_inputs] = out_v->data<T>();
inputs_num_heads[num_inputs] = v->dims()[2];
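Both RoPE kernels now allocate their outputs before the zero-size early return and skip zero-size dk/dv, so a zero-size batch round-trips cleanly. A minimal usage sketch, modeled on the new test below (the import path is an assumption and a CUDA/ROCm build is required):

import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

q = paddle.randn([0, 1, 8, 8], dtype="float32")  # zero-size batch
k = paddle.randn([0, 1, 8, 8], dtype="float32")
v = paddle.randn([0, 1, 8, 8], dtype="float32")
q.stop_gradient = k.stop_gradient = v.stop_gradient = False
sin = paddle.sin(paddle.randn([1, 1, 1, 8], dtype="float32"))
cos = paddle.cos(paddle.randn([1, 1, 1, 8], dtype="float32"))

out_q, out_k, out_v = fused_rotary_position_embedding(
    q, k, v, sin=sin, cos=cos, use_neox_rotary_style=False)
(out_q + out_k + out_v).backward()
assert q.grad.shape == q.shape  # gradients keep the zero-size shapes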
12 changes: 12 additions & 0 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -927,6 +927,18 @@ void FlashAttnGradKernel(const Context& dev_ctx,
if (dv) {
dev_ctx.template Alloc<T>(dv);
}
if (dout.numel() == 0) {
if (dq)
Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(dq->dims())), 0, dq);
if (dk)
Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(dk->dims())), 0, dk);
if (dv)
Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(dv->dims())), 0, dv);
return;
}
FlashAttnGradBaseKernel<T, Context>(dev_ctx,
q,
k,
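In effect, when the incoming gradient has no elements the backward pass defines dq/dk/dv as zero tensors at their declared shapes instead of launching the FlashAttention backward kernel. A small sketch of that convention (shapes are illustrative):

import paddle

# dout is zero-size when, e.g., the query has sequence length 0; the grads are then
# zero-filled at the q/k/v shapes -- the Python-level analogue of Full<T, Context>.
dq = paddle.full([2, 0, 8, 64], 0.0)    # same shape as a zero-length-query q
dk = paddle.full([2, 128, 8, 64], 0.0)  # k may still be non-empty; its grad is all zeros
assert dq.size == 0
assert float(dk.abs().sum()) == 0.0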
25 changes: 25 additions & 0 deletions paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -633,6 +633,31 @@ void FlashAttnKernel(const Context& dev_ctx,
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
if (q.numel() == 0 || k.numel() == 0 || v.numel() == 0) {
if (out) {
Full<T, Context>(
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
}
if (softmax) {
Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(softmax->dims())),
0,
softmax);
}
if (softmax_lse) {
Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(softmax_lse->dims())),
0,
softmax_lse);
}
if (seed_offset) {
Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(seed_offset->dims())),
0,
seed_offset);
}
return;
}
FlashAttnBaseKernel<T, Context>(dev_ctx,
q,
k,
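The new TestAttentionWithBoolMaskZeroSize case at the end of this diff drives this zero-size path from Python. A minimal sketch of the same kind of call (assuming paddle.nn.functional.scaled_dot_product_attention, which the parent test class exercises; a CUDA build is required):

import paddle
import paddle.nn.functional as F

shape = (0, 1, 8, 8)  # zero-size batch, as in the new test
q = paddle.randn(shape, dtype="float32")
k = paddle.randn(shape, dtype="float32")
v = paddle.randn(shape, dtype="float32")
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
assert out.shape == list(shape)  # output keeps the zero-size shape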
50 changes: 50 additions & 0 deletions test/legacy_test/test_fused_rotary_position_embedding.py
@@ -692,5 +692,55 @@ def test_error2():
self.assertRaises(AssertionError, test_error2)


@unittest.skipIf(
not core.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(),
"core is not compiled with CUDA or ROCM ",
)
class TestFusedRotaryPositionEmbeddingZeroSize(unittest.TestCase):
def setUp(self):
self.dtype = "float32"
self.qkv_shape = [0, 1, 8, 8]
self.sin_cos_shape = [1, 1, 1, 8]

def init_data(self):
self.q = paddle.randn(self.qkv_shape, dtype=self.dtype)
self.k = paddle.randn(self.qkv_shape, dtype=self.dtype)
self.v = paddle.randn(self.qkv_shape, dtype=self.dtype)
self.q.stop_gradient = False
self.k.stop_gradient = False
self.v.stop_gradient = False
self.sin = paddle.sin(
paddle.randn(self.sin_cos_shape, dtype=self.dtype)
)
self.cos = paddle.cos(
paddle.randn(self.sin_cos_shape, dtype=self.dtype)
)

def _test_forward_backward(self):
out_q, out_k, out_v = fused_rotary_position_embedding(
self.q,
self.k,
self.v,
sin=self.sin,
cos=self.cos,
use_neox_rotary_style=False,
)
out = out_q + out_k + out_v
out.backward()
np.testing.assert_array_equal(self.q.shape, self.q.grad.shape)
np.testing.assert_array_equal(self.k.shape, self.k.grad.shape)
np.testing.assert_array_equal(self.v.shape, self.v.grad.shape)

def test_zero_size(self):
self.init_data()
self._test_forward_backward()


if __name__ == "__main__":
unittest.main()
9 changes: 9 additions & 0 deletions test/legacy_test/test_scaled_dot_product_attention.py
@@ -220,5 +220,14 @@ def test_3d_input(self):
np.testing.assert_allclose(out.numpy(), out_ref, rtol=5e-03, atol=1e-03)


class TestAttentionWithBoolMaskZeroSize(TestAttentionWithBoolMask):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (0, 1, 8, 8)
self.dtype = 'float32'
self.dropout = 0.0
self.causal = False


if __name__ == '__main__':
unittest.main()