
Conversation

@limin2021
Contributor

PR types

New features

PR changes

OPs

Describe

  1. Goal: this PR aims to improve the computational performance of the attention module.
    To reduce the framework-level op scheduling overhead, the attention module is implemented by hand at the C++ level and exposed as a single large attention op.
    To reduce memory-access overhead, this PR applies two optimizations:
    (1) the q, k, v projections share the input X, so the GEMM, transpose, and bias-add there are reduced from three calls to one;
    (2) kernel fusion is used, so intermediate data is passed through registers instead of going through global memory between separate CUDA kernels.

  2. The computation implemented by fused_attention_op:
    (figure: pseudocode of the fused attention computation)
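The pseudocode figure did not survive extraction. As a rough, non-authoritative NumPy sketch of the computation such a fused op typically covers (assuming pre-layer-norm, and omitting bias, masking, and dropout for brevity; all names here are illustrative, not Paddle's actual identifiers):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(z, eps=1e-5):
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)

def fused_attention_forward(x, qkv_weight, out_weight, num_heads):
    """Sketch of the fused forward logic (bias/mask/dropout omitted)."""
    b, s, d = x.shape
    hd = d // num_heads
    ln_x = layer_norm(x)                         # pre-layernorm
    qkv = ln_x @ qkv_weight                      # one fused GEMM for q, k, v
    q, k, v = (t.reshape(b, s, num_heads, hd).transpose(0, 2, 1, 3)
               for t in np.split(qkv, 3, axis=-1))
    attn = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(hd))
    ctx = (attn @ v).transpose(0, 2, 1, 3).reshape(b, s, d)
    return x + ctx @ out_weight                  # out projection + residual add
```

In the actual op all of these stages run inside fused CUDA kernels, so the intermediates (`qkv`, `attn`, `ctx` above) need not round-trip through global memory.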

  3. Differences between fused_attention_op and Paddle's existing MultiHeadAttention layer:
    (1) The fused op covers a larger span of the computation; see the pseudocode above.
    (2) The q, k, v weights are stored in a different format.
    Existing layer: three separate weight tensors, WQ, WK, WV.
    This PR: a single packed weight tensor, qkv_weight.
    How qkv_weight is obtained from WQ, WK, WV:
    (figure: qkv_weight packing layout)
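The packing-layout figure is not recoverable here, but the idea behind (2) above can be sketched in NumPy: pack WQ, WK, WV into one weight so the shared input X needs only a single GEMM instead of three (the concatenation layout below is illustrative; the op's actual qkv_weight layout may differ):

```python
import numpy as np

# Illustrative sizes, not Paddle's defaults.
batch, seq_len, embed_dim = 2, 4, 8
rng = np.random.default_rng(0)

x = rng.standard_normal((batch, seq_len, embed_dim))
wq, wk, wv = (rng.standard_normal((embed_dim, embed_dim)) for _ in range(3))

# Unfused: three separate GEMMs over the shared input X.
q, k, v = x @ wq, x @ wk, x @ wv

# Fused: pack the three weights once, then run a single GEMM.
qkv_weight = np.concatenate([wq, wk, wv], axis=1)  # [embed_dim, 3*embed_dim]
qkv = x @ qkv_weight                               # one GEMM instead of three
q2, k2, v2 = np.split(qkv, 3, axis=-1)

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)
```

One larger GEMM generally utilizes the GPU better than three small ones, and the single fused call also removes two kernel launches' worth of scheduling overhead.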

  4. Implementation:
    This PR adds the forward implementation of fused_attention_op. Details:

(1) fused_attention_op.cc and fused_attention_op.cu
The C++ forward implementation of fused_attention_op. It builds on these PRs:
#34883, #35308, #35350, #35621, #35903, #36185

(2) functional/fused_transformer.py
The Python API for fused_attention_op. It currently includes only the
dynamic-graph API; the static-graph API will be added in the next PR.

(3) test_fused_attention_op.py
The unit test script for fused_attention_op (dynamic graph, forward pass).

(4) paddle/fluid/operators/dropout_impl_util.h
The changes made to dropout_impl_util.h in #35820 were overwritten by #36185.
This PR restores the file contents to match #35820.

(5) Bug fix: remove a useless "print" in framework.py.

  5. Unittest:
    (figure: unit test results)
@limin2021 limin2021 changed the title [cherry-pick-2.2] Fused attention op forward (#35905) [cherry-pick-2.2] Fused attention op forward Oct 25, 2021
@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Contributor

@zhiqiu zhiqiu left a comment


LGTM

@lanxianghit lanxianghit merged commit d2be870 into PaddlePaddle:release/2.2 Oct 26, 2021
