Skip to content

Conversation

@limin2021
Copy link
Contributor

@limin2021 limin2021 commented Jan 26, 2022

PR types

Performance optimization

PR changes

OPs

Describe

(1) Optimize layer norm baward cuda kernel when cols is 1024.
(1) Optimize fused_dropout_residual_layer_norm op backward kernel when cols is 1024.

Performance results:
(1) for layer_norm op backward kernel:

bsz 28672 14336 7168 3584 1792 896 448 224 112 56
时间(ns):                    
apex_fast_layer_norm 118787.33 69095.6258 40478.324 22281.857 16425.4798 17290.8433 15537.1818 14625.346 14277.9958 14471.547
paddle 248873.331 135595.776 75483.419 45307.2175 27301.4258 21960.7725 18762.2448 18727.795 19095.8335 18361.035
paddle_opt 117072.924 67497.807 38579.6438 21669.2123 16428.6753 14495.7083 13712.1905 14245.3828 14196.2095 14095.0285
加速比:                    
apex/paddle 0.47730036 0.50957064 0.53625451 0.49179487 0.6016345 0.78735132 0.82810889 0.7809433 0.74770215 0.78816619
apex/paddle_opt 1.01464391 1.02367216 1.04921456 1.02827259 0.99980549 1.192825 1.13309261 1.02667273 1.00576113 1.026712858

结论:优化后相比优化前获得20%-50%性能提升;优化后相比竞品基本持平,有些case略有提升。

(2) for fused_dropout_residual_ln op backward kernel:

batch_size*seq_len 28672 14336 7168 3584 1792 896 448 224 112 56
时间:ns                    
nv-mlperf-1.1 219413.3928 118066.406 66881.7083 42488.2435 28121.3475 25126.0745 21853.752 19989.3635 19478.7995 19143.8193
paddle-opt 178724.182 96127.955 52758.1943 28410.584 22318.3545 18894.0908 16635.9843 15426.4945 15144.0065 15136.8995
加速比:                    
nv/paddle-opt 1.227664831 1.22822135 1.26770276 1.49550757 1.26000989 1.32983772 1.31364347 1.29578133 1.28623819 1.26471205

结论:相比竞品获得大约20%左右提升(因为竞品只是将dropout和dresidual计算融合,并没有融合layer_norm_grad,一共有3个kernel;优化后我们只有2个kernel)。

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@limin2021 limin2021 changed the title Optimize layer norm baward cuda kernel when cols is 1024. Optimize layer norm backward cuda kernel when cols is 1024. Jan 26, 2022
@limin2021 limin2021 requested a review from zkh2016 January 29, 2022 05:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants