Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4487,14 +4487,17 @@ void VariableLengthMemoryEfficientAttentionInferMeta(
common::errors::InvalidArgument(
"The head number of Key, Value should be equal."));

PADDLE_ENFORCE_EQ(
query_num_head % key_num_head,
0,
errors::InvalidArgument(
"The num_head of query must be divisible by the num_head of key, but "
"received num_head of query is %d, and the num_head of key is %d",
query_num_head,
key_num_head));
if (key_num_head != 0) {
PADDLE_ENFORCE_EQ(
query_num_head % key_num_head,
0,
errors::InvalidArgument(
"The num_head of query must be divisible by the num_head of key, "
"but "
"received num_head of query is %d, and the num_head of key is %d",
query_num_head,
key_num_head));
}

PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ void MultiHeadAttentionVariableForwardKernel(
const int pre_cache_length,
DenseTensor* output) {
dev_ctx.template Alloc<T>(output);
if (output->numel() == 0) return;

Params params{};
// [B, N, S, H]
params.seq_lens = seq_lens.data<int>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,5 +341,79 @@ def test_all(self):
np.testing.assert_allclose(res[0], self.ref_out, rtol=5e-03, atol=1e-03)


@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.2",
)
class TestMemEffAttentionVariableAPI_ZeroSize(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPIVariable_fp32"
self.place = paddle.CUDAPlace(0)
self.batch_size = 0
self.num_head = 8
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 16
self.seq_lens = paddle.to_tensor(
[
self.seq_len,
]
* self.batch_size,
"int32",
)
self.shape = (
self.batch_size,
self.num_head,
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float32'
self.attention_mask = paddle.zeros([0, 1, 64, 1])
self.scale = 1.0 / np.sqrt(self.shape[-1])

def test_all(self):
paddle.disable_static()

query = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
q.stop_gradient = False
key = np.random.random(self.shape_kv)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
value = np.random.random(self.shape_kv)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)

out_ = naive_attention_impl(q, k, v, self.attention_mask, self.scale)

out = variable_length_memory_efficient_attention(
q,
k,
v,
self.seq_lens,
self.seq_lens,
self.attention_mask,
self.scale,
)

for i in range(self.batch_size):
np.testing.assert_allclose(
out.numpy()[i, :, : self.seq_lens[i], :],
out_[i, :, : self.seq_lens[i], :],
rtol=5e-03,
atol=1e-03,
)


if __name__ == '__main__':
unittest.main()
Loading