
Conversation

hebangwen (Contributor) commented:

Introduction

Thanks for sharing this awesome code. Recently I found that float16 overflows in layer_norm and rms_norm: the reduction saturates to infinity, and dividing by infinity yields all zeros. That is why a typical RMSNorm implementation casts its inputs to float32 and truncates back to f16 after normalization, e.g. Qwen3RMSNorm. Results are shown below.
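For context, here is the pattern the fix restores, as a minimal sketch rather than the repo's actual kernel (the kernel name, the `g` scale parameter, and the 1e-5 eps are illustrative): load f16, square and reduce in f32, and only truncate back to f16 at the very end.

```cuda
#include <cuda_fp16.h>

// Sketch only: one block per row of an (N, K) tensor, blockDim.x == K,
// K a power of two <= 1024. A production kernel would use warp shuffles
// instead of this shared-memory tree, but the accumulation dtype is the point.
template <int K>
__global__ void rms_norm_f16_acc_f32_sketch(const half *x, half *y, float g) {
  __shared__ float s_sq[K];
  const int row = blockIdx.x;
  const int tid = threadIdx.x;

  // Upcast BEFORE squaring: even a single half*half product can hit inf,
  // since the largest finite f16 value is 65504.
  float v = __half2float(x[row * K + tid]);
  s_sq[tid] = v * v;
  __syncthreads();

  // Reduce the row's sum of squares entirely in f32.
  for (int s = K / 2; s > 0; s >>= 1) {
    if (tid < s) s_sq[tid] += s_sq[tid + s];
    __syncthreads();
  }

  // Normalize in f32, then truncate the final result back to f16.
  const float inv_rms = rsqrtf(s_sq[0] / K + 1e-5f);
  y[row * K + tid] = __float2half(v * inv_rms * g);
}
```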

Also, LANUCH_RMS_NORM_F16x8F32_KERNEL accidentally called rms_norm_f16x8_f16_kernel, so the supposedly-f32 path overflowed too. I changed the call back to rms_norm_f16x8_f32_kernel.
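A hedged sketch of the dispatch bug; only the macro name and the two kernel names come from this PR, while `grid`, `block`, `x_ptr`, `y_ptr`, `g`, and `N` are placeholders:

```cuda
// Before the fix, the f32-accumulation launcher named the f16 kernel:
//
//   #define LANUCH_RMS_NORM_F16x8F32_KERNEL(K) \
//     rms_norm_f16x8_f16_kernel<(K)><<<grid, block>>>(x_ptr, y_ptr, g, N)
//
// so requesting the f32 variant still ran the overflowing f16 path. After:
#define LANUCH_RMS_NORM_F16x8F32_KERNEL(K) \
  rms_norm_f16x8_f32_kernel<(K)><<<grid, block>>>(x_ptr, y_ptr, g, N)
```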

Contents

  1. Add an f16 overflow testcase (the overflow mechanism is sketched after this list).
  2. Correct the kernel typo for rmsnorm_f16x8_f32; it called rmsnorm_f16x8_f16 before the fix.
  3. Correct more typos in the LayerNorm & RMSNorm kernels from pull/317.
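As a standalone illustration of the failure mode this testcase targets (the numbers are made up for illustration; the actual test lives in the repo's harness), summing squares in f16 saturates to infinity long before the f32 sum approaches its limit:

```cuda
#include <cstdio>
#include <cuda_fp16.h>

int main() {
  const int K = 512;
  half sum_h = __float2half(0.0f);
  float sum_f = 0.0f;
  for (int i = 0; i < K; ++i) {
    float v = 32.0f;  // each element is well within f16 range
    // Emulate an f16 accumulator: round the partial sum to half each step.
    // After 64 steps the sum reaches 65536 > 65504 and rounds to inf.
    sum_h = __float2half(__half2float(sum_h) + v * v);
    sum_f += v * v;  // f32 accumulator: ends at 524288, nowhere near overflow
  }
  printf("f16 sum of squares: %f\n", __half2float(sum_h));   // inf
  printf("f32 sum of squares: %f\n", sum_f);                 // 524288.0
  printf("x / f16_sum = %f\n", 32.0f / __half2float(sum_h)); // 0.0
  return 0;
}
```

Compiled with nvcc, this prints inf for the f16 sum and 524288.0 for the f32 sum, matching the all-zero rows in the "f16 overflow without f32" sections of the logs below.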

Logs

GPU: RTX 3060Ti Laptop

```
# layernorm
-------------------------------------------------------------------------------------
N=4096, K=512
-------------------------------------------------------------------------------------
out_f32:          ['1.13426268 ', '-0.44016451 ', '0.1694629 '],  time:0.08725476ms
out_f32x4:        ['1.13426256 ', '-0.44016448 ', '0.16946289 '], time:0.06236434ms
out_f32_th:       ['1.13315964 ', '-0.43973646 ', '0.16929811 '], time:0.32872820ms
-------------------------------------------------------------------------------------
out_f16f16:       ['1.13476562 ', '-0.44042969 ', '0.16967773 '], time:0.08019090ms
out_f16f32:       ['1.13378906 ', '-0.44042969 ', '0.16955566 '], time:0.07983351ms
out_f16x2f16:     ['1.13378906 ', '-0.44018555 ', '0.16943359 '], time:0.03940082ms
out_f16x8f16:     ['1.13378906 ', '-0.44018555 ', '0.16943359 '], time:0.03368807ms
out_f16x8packf16: ['1.13378906 ', '-0.44018555 ', '0.16955566 '], time:0.03352642ms
out_f16x8packf32: ['1.13378906 ', '-0.44042969 ', '0.16955566 '], time:0.03344059ms
out_f16_th:       ['1.1328125 ',  '-0.43969727 ', '0.16931152 '], time:0.18646502ms
-------------------------------------------------------------------------------------
f16 overflow without f32
-------------------------------------------------------------------------------------
out_f16f16:       ['0.0 ', '0.0 ', '0.0 '],                       time:0.07995677ms
out_f16f32:       ['1.13378906 ', '-0.44018555 ', '0.16943359 '], time:0.08067250ms
out_f16x2f16:     ['0.0 ', '0.0 ', '0.0 '],                       time:0.04753709ms
out_f16x8f16:     ['0.0 ', '0.0 ', '0.0 '],                       time:0.03417683ms
out_f16x8packf16: ['0.0 ', '0.0 ', '0.0 '],                       time:0.03382993ms
out_f16x8packf32: ['1.13378906 ', '-0.44018555 ', '0.16943359 '], time:0.03370309ms
out_f16_th:       ['1.1328125 ',  '-0.43994141 ', '0.16931152 '], time:0.18657756ms
-------------------------------------------------------------------------------------

# rmsnorm
-------------------------------------------------------------------------------------
N=4096, K=512
out_f32:          ['0.24011645 ', '-0.61607343 ', '2.22712135 '], time:0.07045746ms
out_f32x4:        ['0.24011645 ', '-0.61607343 ', '2.22712135 '], time:0.06801510ms
out_f32_th:       ['0.24011768 ', '-0.61607659 ', '2.2271328 '],  time:0.21679926ms
-------------------------------------------------------------------------------------
out_f16f16:       ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.05328727ms
out_f16f32:       ['0.2401123 ',  '-0.61572266 ', '2.2265625 '],  time:0.05257893ms
out_f16x2f16:     ['0.24023438 ', '-0.61621094 ', '2.22851562 '], time:0.03612614ms
out_f16x8f16:     ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.03403950ms
out_f16x8f32:     ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.03377581ms
out_f16x8packf16: ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.03370738ms
out_f16x8packf32: ['0.2401123 ',  '-0.61572266 ', '2.2265625 '],  time:0.03313088ms
out_f16_th:       ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.12258506ms
-------------------------------------------------------------------------------------
f16 overflow without f32
-------------------------------------------------------------------------------------
out_f16f16:       ['0.0 ', '-0.0 ', '0.0 '],                      time:0.05549264ms
out_f16f32:       ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.05216908ms
out_f16x2f16:     ['0.0 ', '-0.0 ', '0.0 '],                      time:0.03391409ms
out_f16x8f16:     ['0.0 ', '-0.0 ', '0.0 '],                      time:0.03273296ms
out_f16x8f32:     ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.03380895ms
out_f16x8packf16: ['0.0 ', '-0.0 ', '0.0 '],                      time:0.03445339ms
out_f16x8packf32: ['0.23999023 ', '-0.61572266 ', '2.2265625 '],  time:0.03538585ms
out_f16_th:       ['0.0 ', '-0.0 ', '0.0 '],                      time:0.12091827ms
-------------------------------------------------------------------------------------
```
@DefTruth self-requested a review June 10, 2025 03:10
@DefTruth (Member) left a comment:


LGTM~, thanks for this fix!

@DefTruth changed the title from "[Testcase & BugFix] Add f16 overflow testcase & Fix kernel name for rmsnorm_f16x8_f32" to "bugfix: Add layernorm & rmsnorm f16 overflow" Jun 10, 2025
@DefTruth changed the title from "bugfix: Add layernorm & rmsnorm f16 overflow" to "bugfix: fix layernorm & rmsnorm f16 overflow" Jun 10, 2025
@DefTruth merged commit ace5f16 into xlite-dev:main Jun 10, 2025
@hebangwen deleted the test/f16_overflow branch June 10, 2025 04:30
