4 changes: 2 additions & 2 deletions README.md
@@ -44,7 +44,7 @@
 <div align='center'>
 <img src='https://github.com/user-attachments/assets/a5ec4320-d2f9-4254-888a-170b2d9e3784' height=170px>
 </div>
--->
+-->
 
 - [2025-01-08]: **[🤖ffpa-attn](https://github.com/xlite-dev/ffpa-attn.git)** is released! Yet another Faster Flash Prefill Attention with O(1)🎉SRAM complexity for large headdim, **1.8x~3x↑**🎉 vs SDPA EA: [📈L20 ~1.9x↑🎉](https://github.com/xlite-dev/ffpa-attn?tab=readme-ov-file#L1-bench-l20), [📈A30 ~1.8x↑🎉](https://github.com/xlite-dev/ffpa-attn?tab=readme-ov-file#L1-bench-a30),[📈4090 ~2.1x↑🎉](https://github.com/xlite-dev/ffpa-attn?tab=readme-ov-file#L1-bench-4090).
 
@@ -54,7 +54,7 @@
 <img src='https://github.com/user-attachments/assets/447e2937-f7c8-47c8-8550-8c0c71b910e6' height="170px" width="229px">
 <img src='https://github.com/user-attachments/assets/65a8d564-8fa7-4d66-86b9-e238feb86143' height="170px" width="229px">
 </div>
--->
+-->
 <div align='center'>
 <img height="320px" alt="image" src="https://github.com/user-attachments/assets/ed30185b-2e11-4293-832f-43e9003d6ad9" />
 </div>
12 changes: 6 additions & 6 deletions kernels/softmax/softmax.cu
@@ -333,14 +333,14 @@ __global__ void online_safe_softmax_f32_per_token_kernel(const float *x,
   // for softmax)
   int local_tid = threadIdx.x;
   int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
-  const int WAPR_NUM = NUM_THREADS / WARP_SIZE;
+  const int WARP_NUM = NUM_THREADS / WARP_SIZE;
   int warp_id = local_tid / WARP_SIZE;
   int lane_id = local_tid % WARP_SIZE;
   MD val;
   val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
   val.d = global_tid < N ? 1.0f : 0.0f;
 
-  __shared__ MD shared[WAPR_NUM];
+  __shared__ MD shared[WARP_NUM];
   MD res = warp_reduce_md_op<WARP_SIZE>(val);
 
   if (lane_id == 0)
@@ -349,7 +349,7 @@
 
   if (local_tid < WARP_SIZE) {
     MD block_res = shared[local_tid];
-    block_res = warp_reduce_md_op<WAPR_NUM>(block_res);
+    block_res = warp_reduce_md_op<WARP_NUM>(block_res);
     if (local_tid == 0) {
       shared[0] = block_res;
     }
@@ -371,7 +371,7 @@ online_safe_softmax_f32x4_pack_per_token_kernel(float *x, float *y, int N) {
   int local_tid = threadIdx.x;
   int global_tid = (blockIdx.x * NUM_THREADS + local_tid) * 4;
 
-  const int WAPR_NUM = NUM_THREADS / WARP_SIZE;
+  const int WARP_NUM = NUM_THREADS / WARP_SIZE;
   int warp_id = local_tid / WARP_SIZE;
   int lane_id = local_tid % WARP_SIZE;
   // compare local max value
@@ -382,15 +382,15 @@
 
   MD local_md = {local_m, local_d};
   MD res = warp_reduce_md_op<WARP_SIZE>(local_md);
-  __shared__ MD shared[WAPR_NUM];
+  __shared__ MD shared[WARP_NUM];
 
   if (lane_id == 0)
     shared[warp_id] = res;
   __syncthreads();
   // do block reduce
   if (local_tid < WARP_SIZE) {
     MD block_res = shared[local_tid];
-    block_res = warp_reduce_md_op<WAPR_NUM>(block_res);
+    block_res = warp_reduce_md_op<WARP_NUM>(block_res);
     if (local_tid == 0)
       shared[0] = block_res;
   }
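For context, both kernels in this diff follow the same pattern: each thread folds its input into an `MD` pair (running max `m`, running denominator `d`), a warp-level reduce combines the 32 lanes, lane 0 of each warp parks its partial in shared memory, and the first `WARP_NUM` lanes run a second reduce over those partials. The `MD` struct and `warp_reduce_md_op` are defined earlier in `softmax.cu` and are not shown in the hunks above; the sketch below is a minimal reconstruction assuming the standard online-softmax merge (Milakov & Gimelshein, 2018), with `md_merge` as a hypothetical helper name, not a copy of the repo's exact code.

```cuda
#include <cfloat>
#include <cuda_runtime.h>

// Running (max, denominator) state for online safe softmax.
struct MD {
  float m; // running maximum of the inputs seen so far
  float d; // running sum of exp(x - m)
};

// Merge two partial states: rescale each denominator to the joint max
// before summing, so the pair can be combined in any order.
__device__ __forceinline__ MD md_merge(MD a, MD b) {
  MD out;
  out.m = fmaxf(a.m, b.m);
  out.d = a.d * __expf(a.m - out.m) + b.d * __expf(b.m - out.m);
  return out;
}

// Butterfly reduction over kLanes lanes (kLanes must be a power of two:
// WARP_SIZE for the intra-warp pass, WARP_NUM for the block pass).
template <const int kLanes>
__device__ __forceinline__ MD warp_reduce_md_op(MD val) {
#pragma unroll
  for (int offset = kLanes >> 1; offset > 0; offset >>= 1) {
    MD other;
    other.m = __shfl_xor_sync(0xffffffffu, val.m, offset);
    other.d = __shfl_xor_sync(0xffffffffu, val.d, offset);
    val = md_merge(val, other);
  }
  return val;
}
```

Because the merge rescales each partial denominator by `exp(old_max - new_max)`, a whole warp's contribution collapses to a single `(m, d)` pair, which is why the block reduce above only needs one `MD` slot per warp in shared memory — and why sizing that array with the misspelled `WAPR_NUM` still compiled: both constants name the same value, so this PR is purely a spelling fix with no behavior change.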