Add nn.functional.sparse_attention and some test cases, test=develop #35757
Changes from all commits
089e738 fd26a39 d2c0cc4 47828d2 c21c203 9be4aac 3299354 1bf2601 23d6756 64599d1 0ea0f11 a967b86 b49797d
|  | @@ -16,10 +16,13 @@ | |
| import numpy as np | ||
| from op_test import OpTest | ||
| import paddle.fluid.core as core | ||
| from paddle.static import Program, program_guard | ||
| import paddle | ||
| import paddle.fluid as fluid | ||
| import paddle.fluid.framework as framework | ||
| import paddle.nn.functional as F | ||
| import os | ||
| import re | ||
| import platform | ||
|  | ||
|  | ||
| def get_cuda_version(): | ||
|  | @@ -34,22 +37,6 @@ def get_cuda_version(): | |
| return -1 | ||
|  | ||
|  | ||
| def get_linux_platform(): | ||
| if platform.system().lower() == 'windows': | ||
| return 0 | ||
| elif platform.system().lower() == 'linux': | ||
| return 1 | ||
| else: | ||
| return -1 | ||
|  | ||
|  | ||
| def get_suitable_env(): | ||
| if get_cuda_version() >= 11020 and get_linux_platform() == 1: | ||
| return True | ||
| else: | ||
| return False | ||
|  | ||
|  | ||
| def softmax(x): | ||
| max = np.max(x, axis=1, keepdims=True) | ||
| e_x = np.exp(x - max) | ||
|  | @@ -141,8 +128,9 @@ def init_csr_format(batch_size, num_heads, rows, blocksize): | |
|  | ||
|  | ||
| @unittest.skipIf( | ||
| not core.is_compiled_with_cuda() or get_suitable_env() == False, | ||
| "core is not compiled with CUDA and cuda version need >= 11.2 in windows") | ||
| not core.is_compiled_with_cuda() or get_cuda_version() < 11020, | ||
| "core is not compiled with CUDA and cuda version need larger than or equal to 11.2" | ||
| ) | ||
| class TestSparseAttentionOp(OpTest): | ||
| def config(self): | ||
| self.shape = (1, 1, 16, 8) | ||
|  | @@ -201,5 +189,130 @@ def config(self): | |
| self.dtype = "float64" | ||
|  | ||
|  | ||
| @unittest.skipIf( | ||
| not core.is_compiled_with_cuda() or get_cuda_version() < 11020, | ||
| "core is not compiled with CUDA and cuda version need larger than or equal to 11.2" | ||
| ) | ||
| class TestSparseAttentionAPI(unittest.TestCase): | ||
| def setUp(self): | ||
| self.place = paddle.CUDAPlace(0) | ||
| self.shape = (1, 1, 8, 4) | ||
| self.blocksize = 2 | ||
| self.dtype = 'float64' | ||
|  | ||
| def test_static_graph(self): | ||
| paddle.enable_static() | ||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| Q = paddle.static.data(name="Q", shape=self.shape, dtype=self.dtype) | ||
| K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype) | ||
| V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype) | ||
|  | ||
| batch_size, num_heads, rows = self.shape[0], self.shape[ | ||
| 1], self.shape[2] | ||
| block_num = rows // self.blocksize  # integer division: number of full blocks | ||
| block_last = rows % self.blocksize | ||
| sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last | ||
| offset_shape = (batch_size, num_heads, rows + 1) | ||
| columns_shape = (batch_size, num_heads, int(sparse_nnz_num)) | ||
|  | ||
| offset = paddle.static.data( | ||
| name="Offset", shape=offset_shape, dtype="int32") | ||
| columns = paddle.static.data( | ||
| name="Columns", shape=columns_shape, dtype="int32") | ||
| Out = F.sparse_attention(Q, K, V, offset, columns) | ||
|  | ||
| Q_np = np.random.random(self.shape).astype(self.dtype) | ||
| K_np = np.random.random(self.shape).astype(self.dtype) | ||
| V_np = np.random.random(self.shape).astype(self.dtype) | ||
| offset_np, columns_np = init_csr_format( | ||
| self.shape[0], self.shape[1], self.shape[2], self.blocksize) | ||
| offset_np = offset_np.astype('int32') | ||
| columns_np = columns_np.astype('int32') | ||
|  | ||
| exe = fluid.Executor(self.place) | ||
| fetches_result = exe.run(feed={ | ||
| "Q": Q_np, | ||
| "K": K_np, | ||
| "V": V_np, | ||
| "Offset": offset_np, | ||
| "Columns": columns_np | ||
| }, | ||
| fetch_list=[Out]) | ||
| expected_result, __, __ = ref_batch_sparse_attention( | ||
| Q_np, K_np, V_np, offset_np, columns_np) | ||
|  | ||
| self.assertTrue( | ||
| np.allclose( | ||
| fetches_result, expected_result, atol=1e-5)) | ||
|  | ||
| def test_dygraph(self): | ||
| paddle.disable_static() | ||
| offset, columns = init_csr_format(self.shape[0], self.shape[1], | ||
| self.shape[2], self.blocksize) | ||
| offset = offset.astype('int32') | ||
| columns = columns.astype('int32') | ||
| query = np.random.random(self.shape).astype(self.dtype) | ||
| key = np.random.random(self.shape).astype(self.dtype) | ||
| value = np.random.random(self.shape).astype(self.dtype) | ||
|  | ||
| paddle_query = paddle.to_tensor(query, place=self.place) | ||
| paddle_key = paddle.to_tensor(key, place=self.place) | ||
| paddle_value = paddle.to_tensor(value, place=self.place) | ||
| paddle_offset = paddle.to_tensor(offset, place=self.place) | ||
| paddle_columns = paddle.to_tensor(columns, place=self.place) | ||
|  | ||
| paddle_result = F.sparse_attention(paddle_query, paddle_key, | ||
| paddle_value, paddle_offset, | ||
| paddle_columns) | ||
|  | ||
| numpy_result, __, __ = ref_batch_sparse_attention(query, key, value, | ||
| offset, columns) | ||
| numpy_result = numpy_result.astype(self.dtype) | ||
|  | ||
| self.assertTrue( | ||
| np.allclose( | ||
| paddle_result.numpy(), numpy_result, atol=1e-5)) | ||
|  | ||
|  | ||
| class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI): | ||
Review comment: Shapes that are not powers of two are supported, right?
Reply: Yes, they are supported; a non-power-of-two shape test has been added to the unit tests.
| def setUp(self): | ||
| self.place = paddle.CUDAPlace(0) | ||
| self.shape = (2, 2, 8, 4) | ||
| self.blocksize = 2 | ||
| self.dtype = 'float32' | ||
|  | ||
|  | ||
| class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI): | ||
| def setUp(self): | ||
| self.place = paddle.CUDAPlace(0) | ||
| self.shape = (2, 2, 64, 32) | ||
| self.blocksize = 2 | ||
| self.dtype = 'float64' | ||
|  | ||
|  | ||
| class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI): | ||
| def setUp(self): | ||
| self.place = paddle.CUDAPlace(0) | ||
| self.shape = (2, 1, 64, 32) | ||
| self.blocksize = 2 | ||
| self.dtype = 'float64' | ||
|  | ||
|  | ||
| class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI): | ||
| def setUp(self): | ||
| self.place = paddle.CUDAPlace(0) | ||
| self.shape = (4, 4, 128, 32) | ||
| self.blocksize = 8 | ||
| self.dtype = 'float64' | ||
|  | ||
|  | ||
| class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI): | ||
| def setUp(self): | ||
| self.place = paddle.CUDAPlace(0) | ||
| self.shape = (3, 3, 35, 15) | ||
| self.blocksize = 3 | ||
| self.dtype = 'float64' | ||
|  | ||
|  | ||
| if __name__ == '__main__': | ||
| unittest.main() | ||
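For reference, here is a minimal sketch of the block-diagonal CSR layout that `init_csr_format` appears to produce (its body is not shown in this diff, so the helper name `blocked_csr_layout` and the exact layout are assumptions). The non-zero count matches the `sparse_nnz_num` arithmetic in `test_static_graph`: `block_num` full blocks of `blocksize * blocksize` entries, plus one trailing `block_last * block_last` block when `rows` is not a multiple of `blocksize`.

```python
import numpy as np

def blocked_csr_layout(batch_size, num_heads, rows, blocksize):
    # Hypothetical helper: block-diagonal mask with full `blocksize` blocks
    # and one smaller trailing block when rows % blocksize != 0.
    block_num = rows // blocksize
    block_last = rows % blocksize
    nnz = block_num * blocksize * blocksize + block_last * block_last

    offset = np.zeros(rows + 1, dtype=np.int32)   # CSR row pointer
    columns = np.zeros(nnz, dtype=np.int32)       # CSR column indices
    pos = 0
    for row in range(rows):
        block_start = (row // blocksize) * blocksize
        width = blocksize if row < block_num * blocksize else block_last
        columns[pos:pos + width] = np.arange(block_start, block_start + width)
        pos += width
        offset[row + 1] = pos

    # The same layout is replicated for every (batch, head) pair.
    return (np.tile(offset, (batch_size, num_heads, 1)),
            np.tile(columns, (batch_size, num_heads, 1)))

offset, columns = blocked_csr_layout(1, 1, 8, 2)
print(offset.shape, columns.shape)  # (1, 1, 9) (1, 1, 16)
```

For the non-power-of-two case exercised by `TestSparseAttentionAPITestShape4` (rows=35, blocksize=3), this yields 11 * 9 + 2 * 2 = 103 non-zeros, which is why integer division (`rows // self.blocksize`) matters in the nnz computation above.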
| @@ -0,0 +1,144 @@ | ||
| # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|  | ||
| import warnings | ||
| import paddle | ||
| from ...fluid.framework import in_dygraph_mode, default_main_program | ||
| from paddle.fluid.layer_helper import LayerHelper | ||
| from paddle import _C_ops | ||
|  | ||
|  | ||
| def sparse_attention(query, | ||
| key, | ||
| value, | ||
| sparse_csr_offset, | ||
| sparse_csr_columns, | ||
| name=None): | ||
| r""" | ||
| This operator sparsifies the attention matrix in the Transformer module | ||
| to reduce memory consumption and computation. | ||
| The sparse layout is expressed in CSR format and contains two parameters, | ||
| ``offset`` and ``columns``. | ||
|  | ||
| .. math:: | ||
|  | ||
| result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V | ||
|  | ||
| where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. | ||
| The dimensions of the three parameters are the same. | ||
| ``d`` represents the size of the last dimension of the three parameters. | ||
|  | ||
| Parameters: | ||
| query(Tensor): The query tensor in the Attention module. | ||
| It's a 4-D tensor with a shape of | ||
| :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. | ||
| The dtype can be ``float32`` or ``float64``. | ||
Review comment: Since the shape is fixed at 4-D, the description can state the dimensions directly.
Reply: Done.
| key(Tensor): The key tensor in the Attention module. | ||
| It's a 4-D tensor with a shape of | ||
| :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. | ||
| The dtype can be ``float32`` or ``float64``. | ||
| value(Tensor): The value tensor in the Attention module. | ||
| It's a 4-D tensor with a shape of | ||
| :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. | ||
| The dtype can be ``float32`` or ``float64``. | ||
| sparse_csr_offset(Tensor): The sparsity feature in the Attention module | ||
| is expressed in the CSR format; the offset is the CSR row pointer, in which | ||
| entry ``i`` gives the number of non-zero elements in the first ``i`` rows of the matrix. | ||
| It's a 3-D tensor with a shape of | ||
| :math:`[batch\_size, num\_heads, seq\_len + 1]`. | ||
| The dtype should be ``int32``. | ||
| sparse_csr_columns(Tensor): The sparsity feature in the Attention module | ||
| is expressed in the CSR format, and the columns represent | ||
| the column index values of non-zero elements in the matrix. | ||
| It's a 3-D tensor with a shape of | ||
| :math:`[batch\_size, num\_heads, sparse\_nnz]`. | ||
| The dtype should be ``int32``. | ||
| name(str, optional): The default value is None. Normally there is no need for user | ||
| to set this property. For more information, please refer to | ||
| :ref:`api_guide_Name`. | ||
|  | ||
| Returns: | ||
| A Tensor holding the result of the attention computation. | ||
| It's a 4-D tensor with a shape of | ||
| :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. | ||
| The dtype can be ``float32`` or ``float64``. | ||
|  | ||
| Examples: | ||
| .. code-block:: python | ||
|  | ||
| # required: skiptest | ||
| import paddle | ||
| import numpy as np | ||
|  | ||
| query_data = np.array([[[[0, 1,], [2, 3], | ||
| [ 0, 1], [2, 3]]]]).astype("float32") | ||
| key_data = np.array([[[[0, 1,], [2, 3], | ||
| [ 0, 1], [2, 3]]]]).astype("float32") | ||
| value_data = np.array([[[[0, 1,], [2, 3], | ||
| [ 0, 1], [2, 3]]]]).astype("float32") | ||
| sparse_csr_offset_data = np.array([[[0, 2, | ||
| 4, 6, 8]]]).astype("int32") | ||
| sparse_csr_columns_data = np.array([[[0, 1, | ||
| 0, 1, 2, 3, 2, 3]]]).astype("int32") | ||
| print(query_data.shape) | ||
| # (1, 1, 4, 2) | ||
| print(sparse_csr_offset_data.shape) | ||
| # (1, 1, 5) | ||
| print(sparse_csr_columns_data.shape) | ||
| # (1, 1, 8) | ||
| paddle.disable_static() | ||
| query = paddle.to_tensor(query_data, stop_gradient=False, | ||
| place=paddle.CUDAPlace(0)) | ||
| key = paddle.to_tensor(key_data, stop_gradient=False, | ||
| place=paddle.CUDAPlace(0)) | ||
| value = paddle.to_tensor(value_data, stop_gradient=False, | ||
| place=paddle.CUDAPlace(0)) | ||
| offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, | ||
| place=paddle.CUDAPlace(0)) | ||
| columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, | ||
| place=paddle.CUDAPlace(0)) | ||
| output = paddle.nn.functional.sparse_attention(query, key, | ||
| value, offset, columns) | ||
| print(output) | ||
|  | ||
| # [[[[1.60885942, 2.60885954], | ||
| # [1.99830270, 2.99830270], | ||
| # [1.60885942, 2.60885954], | ||
| # [1.99830270, 2.99830270]]]] | ||
| """ | ||
| if in_dygraph_mode(): | ||
| result_attention, result_sdd, result_softmax = _C_ops.sparse_attention( | ||
| query, key, value, sparse_csr_offset, sparse_csr_columns) | ||
| return result_attention | ||
|  | ||
| helper = LayerHelper('sparse_attention', **locals()) | ||
| dtype = helper.input_dtype(input_param_name='query')  # 'query' is the kwarg captured via **locals() | ||
| out = helper.create_variable_for_type_inference(dtype) | ||
| result_sdd = helper.create_variable_for_type_inference(dtype) | ||
| result_softmax = helper.create_variable_for_type_inference(dtype) | ||
| inputs = { | ||
| 'Q': query, | ||
| 'K': key, | ||
| 'V': value, | ||
| 'Offset': sparse_csr_offset, | ||
| 'Columns': sparse_csr_columns | ||
| } | ||
| outputs = { | ||
| 'Out': out, | ||
| 'SparseDotSdd': result_sdd, | ||
| 'Softmax': result_softmax | ||
| } | ||
| helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs) | ||
| return out | ||
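As a sanity check, the expected output in the docstring example can be reproduced with a short NumPy loop over the CSR layout. This is a sketch, not part of the PR; in the example, query, key, and value happen to hold the same data, so a single array stands in for all three:

```python
import numpy as np

qkv = np.array([[0., 1.], [2., 3.], [0., 1.], [2., 3.]], dtype=np.float32)
offset = np.array([0, 2, 4, 6, 8])            # CSR row pointer
columns = np.array([0, 1, 0, 1, 2, 3, 2, 3])  # CSR column indices
d = qkv.shape[-1]

out = np.zeros_like(qkv)
for row in range(qkv.shape[0]):
    cols = columns[offset[row]:offset[row + 1]]
    # softmax(Q K^T / sqrt(d)) restricted to this row's non-zero columns
    scores = qkv[row] @ qkv[cols].T / np.sqrt(d)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    out[row] = weights @ qkv[cols]

print(out)
# approximately:
# [[1.6088595 2.6088595]
#  [1.9983027 2.9983027]
#  [1.6088595 2.6088595]
#  [1.9983027 2.9983027]]
```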
Review comment: The function names test_static_result and test_dygraph don't read like a matching pair; if necessary, consider polishing the naming.
Reply: Done, the names have been changed.
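The docstring example above covers dynamic graph mode only. Below is a static-graph usage sketch following the pattern of `test_static_graph` in the unit tests; it requires a CUDA 11.2+ build per the tests' skipIf guards, and the fully dense 4x4 layout (every row keeps all 4 columns) is chosen purely for simplicity of illustration:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()
shape = (1, 1, 4, 2)  # [batch_size, num_heads, seq_len, head_dim]

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    Q = paddle.static.data(name="Q", shape=shape, dtype="float32")
    K = paddle.static.data(name="K", shape=shape, dtype="float32")
    V = paddle.static.data(name="V", shape=shape, dtype="float32")
    # Dense layout for a 4x4 attention matrix: offset = [0, 4, 8, 12, 16],
    # columns = [0, 1, 2, 3] repeated for each of the 4 rows (16 non-zeros).
    offset = paddle.static.data(name="Offset", shape=(1, 1, 5), dtype="int32")
    columns = paddle.static.data(name="Columns", shape=(1, 1, 16), dtype="int32")
    out = F.sparse_attention(Q, K, V, offset, columns)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
result = exe.run(main_prog,
                 feed={
                     "Q": np.random.random(shape).astype("float32"),
                     "K": np.random.random(shape).astype("float32"),
                     "V": np.random.random(shape).astype("float32"),
                     "Offset": np.array([[[0, 4, 8, 12, 16]]], dtype=np.int32),
                     "Columns": np.tile(np.arange(4, dtype=np.int32),
                                        4).reshape(1, 1, 16),
                 },
                 fetch_list=[out])
print(result[0].shape)  # (1, 1, 4, 2)
```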