@Liu-xiandong
Member

PR types

New features

PR changes

APIs

Describe

cherry-pick #PR35757
Add paddle.nn.functional.sparse_attention API

  • This PR mainly wraps the sparse_attention functionality at the Python layer. The main OP code is in #PR35676.

  • In addition, unit tests are added for the wrapped Python interface.

Example

    import paddle
    import numpy as np

    query_data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]]).astype("float32")
    key_data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]]).astype("float32")
    value_data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]]).astype("float32")
    sparse_csr_offset_data = np.array([[[0, 2, 4, 6, 8]]]).astype("int32")
    sparse_csr_columns_data = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]]).astype("int32")
    print(query_data.shape)
    # (1, 1, 4, 2)
    print(sparse_csr_offset_data.shape)
    # (1, 1, 5)
    print(sparse_csr_columns_data.shape)
    # (1, 1, 8)

    paddle.disable_static()
    query = paddle.to_tensor(query_data, stop_gradient=False, place=paddle.CUDAPlace(0))
    key = paddle.to_tensor(key_data, stop_gradient=False, place=paddle.CUDAPlace(0))
    value = paddle.to_tensor(value_data, stop_gradient=False, place=paddle.CUDAPlace(0))
    offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, place=paddle.CUDAPlace(0))
    columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, place=paddle.CUDAPlace(0))
    output = paddle.nn.functional.sparse_attention(query, key, value, offset, columns)
    print(output)
    # [[[[1.60885942, 2.60885954],
    #    [1.99830270, 2.99830270],
    #    [1.60885942, 2.60885954],
    #    [1.99830270, 2.99830270]]]]
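To make the role of the CSR `offset` and `columns` inputs in the example above easier to follow, here is a minimal CPU sketch in NumPy. It is not the CUDA kernel from this PR; the function name `sparse_attention_reference` is hypothetical, and it assumes scaled dot-product attention (scores scaled by 1/sqrt(head_dim), softmax taken only over the columns listed for each row) and that every row lists at least one column.

```python
import numpy as np


def sparse_attention_reference(query, key, value, offset, columns):
    """Hypothetical NumPy reference; not Paddle's implementation."""
    batch, heads, seq_len, head_dim = query.shape
    out = np.zeros_like(value, dtype=np.float64)
    scale = 1.0 / np.sqrt(head_dim)  # assumption: scaled dot-product attention
    for b in range(batch):
        for h in range(heads):
            for row in range(seq_len):
                # CSR layout: the attended columns of `row` are
                # columns[offset[row]:offset[row + 1]].
                start, end = offset[b, h, row], offset[b, h, row + 1]
                cols = columns[b, h, start:end]
                # Scores only for the listed key positions.
                scores = key[b, h, cols] @ query[b, h, row] * scale
                scores = scores - scores.max()  # numerically stable softmax
                weights = np.exp(scores)
                weights /= weights.sum()
                out[b, h, row] = weights @ value[b, h, cols]
    return out


# Same inputs as the example: query, key and value share one array.
data = np.array([[[[0, 1], [2, 3], [0, 1], [2, 3]]]], dtype="float64")
offset = np.array([[[0, 2, 4, 6, 8]]], dtype="int32")
columns = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]], dtype="int32")

result = sparse_attention_reference(data, data, data, offset, columns)
print(result)
```

Under these assumptions the sketch agrees with the output quoted in the example to about five decimal places, which suggests the example's sparsity pattern is block-diagonal: rows 0 and 1 attend to columns 0 and 1, and rows 2 and 3 attend to columns 2 and 3.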

Result

Because the CI platform currently has no machines with CUDA 11.2, the results computed locally are pasted below:

  1. Local unit test results
    image
  2. API example results
    image
…addlePaddle#35757) Add paddle.nn.functional.sparse_attention API. This PR mainly wraps the sparse_attention functionality at the Python layer. The main OP code is in #PR35676. In addition, unit tests are added for the wrapped Python interface.
@paddle-bot-old

Thanks for your contribution!
Please wait for the CI result first. See Paddle CI Manual for details.

@lanxianghit lanxianghit merged commit c57d1e9 into PaddlePaddle:release/2.2 Oct 25, 2021