
Conversation

@bythew3i (Contributor) commented Mar 5, 2025

Tested:

```
python test/test_pallas.py -v -k PallasTest.test_ragged_paged_attention_wrapper
```

Please Read

This PR adds validation of the ragged attention inputs to torch.ops.xla.ragged_paged_attention and expects it to run at runtime. Please move the validation code out if we ever have to compile something like the snippet below (or just avoid compiling it):

```python
def ragged_paged_attention_wrapper(...):
    ...
    return torch.ops.xla.ragged_paged_attention(...)

compiled_paged_attention = torch.compile(
    ragged_paged_attention_wrapper, backend="openxla")
```
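As a hedged sketch of one way to honor that request, the wrapper below keeps validation eager so `torch.compile` only traces the kernel call. `validate_ragged_attn_inputs` is a hypothetical helper introduced here for illustration, not part of this PR; the real checks live inside `torch.ops.xla.ragged_paged_attention` and run at runtime.

```python
import torch

def validate_ragged_attn_inputs(*args, **kwargs):
    # Hypothetical placeholder: the actual validation added by this PR lives
    # inside torch.ops.xla.ragged_paged_attention and runs at runtime.
    pass

def ragged_paged_attention_wrapper(*args, **kwargs):
    return torch.ops.xla.ragged_paged_attention(*args, **kwargs)

compiled_paged_attention = torch.compile(
    ragged_paged_attention_wrapper, backend="openxla")

def paged_attention(*args, **kwargs):
    validate_ragged_attn_inputs(*args, **kwargs)  # runs eagerly, never traced
    return compiled_paged_attention(*args, **kwargs)
```

This keeps the runtime checks intact while the compiled graph stays free of data-dependent assertions.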

Key Features in Ragged Paged Attention V2

  • Supports mixed prefill and decode to increase inference throughput (e.g., a 5x speedup over the padded Multi-Queries Paged Attention implementation for llama-3-8b).
  • No explicit swapaxes of seq_len and num_head before or after the kernel. The kernel takes num_head in the second-minor dimension, as it naturally is; the swapaxes is folded into strided loads/stores inside the kernel, applying the transpose on the fly.
  • No GMM (Grouped Matmul) metadata required: the metadata is computed on the fly inside the kernel, which alone yields about a 10% speedup.
  • Increases MXU utilization up to 8x for GQA in decode by grouping the q heads that share a KV head into a single matmul (see the sketch after this list).
  • Minimizes recompilation: the only factors that can trigger recompilation are the model specs, max_num_batched_tokens, and max_num_seqs in the mixed-engine setting.
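
To make the GQA grouping concrete, here is a minimal plain-PyTorch sketch of the idea (illustrative shapes, not the Pallas kernel itself): query heads that share one KV head are batched into a single matmul, so the MXU sees a [group, head_dim] left operand instead of `group` separate [1, head_dim] rows.

```python
import torch

num_q_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 1024
group = num_q_heads // num_kv_heads  # q heads sharing each KV head

q = torch.randn(num_q_heads, head_dim)           # one decode token
k = torch.randn(num_kv_heads, kv_len, head_dim)  # cached keys per KV head

# Grouped: one [group, head_dim] x [head_dim, kv_len] matmul per KV head.
qg = q.reshape(num_kv_heads, group, head_dim)
scores = torch.einsum("ghd,gkd->ghk", qg, k)     # [num_kv_heads, group, kv_len]

# Sanity check: matches the naive per-q-head computation.
k_rep = k.repeat_interleave(group, dim=0)        # [num_q_heads, kv_len, head_dim]
naive = torch.einsum("hd,hkd->hk", q, k_rep)
assert torch.allclose(scores.reshape(num_q_heads, kv_len), naive, atol=1e-5)
```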

Note: this PR does not include tests for the Ragged Paged Attention kernel itself, because it is already tested in jax-ml/jax#26920, and in the future we will import that source directly instead of keeping duplicated implementations.

@yaochengji yaochengji self-requested a review March 5, 2025 17:49
@yaochengji (Collaborator) left a comment


LGTM, thanks!

@yaochengji yaochengji enabled auto-merge (squash) March 5, 2025 17:50
@yaochengji yaochengji merged commit 5644f44 into pytorch:master Mar 5, 2025
22 of 23 checks passed
pgmoka pushed a commit that referenced this pull request Mar 5, 2025
@zpcore (Member) commented Mar 5, 2025

The test `test_ragged_paged_attention_wrapper_with_padding_with_dynamo2` is failing. Can someone help fix it? Thanks.

@yaochengji (Collaborator) commented

@zpcore, thanks, it is fixed in #8797.

