9 changes: 6 additions & 3 deletions test/test_pallas.py
@@ -751,12 +751,15 @@ def ragged_paged_attention_wrapper(
num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
from torch_xla.experimental.custom_kernel import _get_default_ragged_paged_attention_block_size
from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
if num_kv_pages_per_block is None:
assert num_queries_per_block is None
token_num = q.shape[0]
num_kv_pages_per_block, num_queries_per_block = _get_default_ragged_paged_attention_block_size(
token_num)
token_num, q_head_num, _ = q.shape
kv_head_num = kv_pages.shape[2] // 2
max_model_len = 2048
num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
q_head_num, kv_head_num, token_num, max_model_len)
jax_kernel_output = torch.from_numpy(
np.array(
jax_ragged_paged_attention(
32 changes: 6 additions & 26 deletions torch_xla/experimental/custom_kernel.py
@@ -9,6 +9,7 @@
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

# Re-expose this API that is referenced by the docs
from torch_xla._internal.jax_workarounds import jax_import_guard # noqa: F401, pylint: disable=unused-import
@@ -915,29 +916,6 @@ def _ragged_paged_attention_nonkernel(
return torch.cat(outputs, dim=0)


def _get_default_ragged_paged_attention_block_size(token_num):
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
if tpu_version == 4:
# This default block size is not tuned; it only makes sure there is no
# OOM in VMEM.
num_kv_pages_per_block = 16
num_queries_per_block = 128
return num_kv_pages_per_block, num_queries_per_block

# This heuristic is based on initial kernel micro-benchmarking: when
# token_num is small, there are no long prefill requests; when it is larger,
# the block size is adjusted accordingly.
if token_num <= 128:
num_kv_pages_per_block = 128
num_queries_per_block = 32
else:
num_kv_pages_per_block = 128
num_queries_per_block = 96
return num_kv_pages_per_block, num_queries_per_block


@requires_jax
def ragged_paged_attention(
q, # [max_num_batched_tokens, num_q_heads, head_dim]
@@ -952,6 +930,7 @@ def ragged_paged_attention(
soft_cap: float | None = None,
mask_value=None,
use_kernel=True,
max_model_len=2048,  # Used as a hint for kernel block-size selection
# kernel tuning parameters
num_kv_pages_per_block=None,
num_queries_per_block=None,
@@ -980,9 +959,10 @@

if num_kv_pages_per_block is None:
assert num_queries_per_block is None
token_num = q.shape[0]
num_kv_pages_per_block, num_queries_per_block = _get_default_ragged_paged_attention_block_size(
token_num)
token_num, q_head_num, _ = q.shape
kv_head_num = kv_pages.shape[2] // 2
num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
q_head_num, kv_head_num, token_num, max_model_len)

if vmem_limit_bytes is None:
vmem_limit_bytes = 64 * 1024 * 1024
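For reference, a minimal sketch of how the new defaulting path derives the lookup inputs, assuming kv_pages is laid out as [num_pages, page_size, 2 * num_kv_heads, head_dim] with K and V interleaved along the head axis; the tensor shapes below are made up for illustration.

import torch
from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

q = torch.empty(2048, 32, 128)  # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages = torch.empty(512, 16, 16, 128)  # assumed [num_pages, page_size, 2 * num_kv_heads, head_dim]

token_num, q_head_num, _ = q.shape  # 2048 tokens, 32 query heads
kv_head_num = kv_pages.shape[2] // 2  # 8 KV heads
num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
    q_head_num, kv_head_num, token_num, max_model_len=2048)

Passing num_kv_pages_per_block / num_queries_per_block explicitly to ragged_paged_attention still bypasses this lookup; max_model_len is only consulted on the defaulting path.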
79 changes: 79 additions & 0 deletions torch_xla/experimental/tuned_block_sizes.py
@@ -0,0 +1,79 @@
import torch_xla


def _next_power_of_2_bit_manipulation(x):
"""
Finds the smallest power of 2 >= x using bit manipulation.
Assumes x is an integer.

Args:
x: The input number (should be an integer).

Returns:
The smallest integer power of 2 that is >= x.
Returns 1 if x <= 0.
"""
if x <= 0:
return 1
if x == 1:
return 1
return 1 << (x - 1).bit_length()
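# Worked example of the rounding: for x = 5, (5 - 1).bit_length() == 3 and
# 1 << 3 == 8, the smallest power of two >= 5; for an exact power of two such
# as x = 8, (8 - 1).bit_length() == 3 and the result stays 8.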


# ragged_paged_attention
# key: (q_head_num, kv_head_num, token_num, max_model_len)
# value: (num_kv_pages_per_block, num_queries_per_block)


def _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
max_model_len):
token_num = _next_power_of_2_bit_manipulation(token_num)
max_model_len = _next_power_of_2_bit_manipulation(max_model_len)
return q_head_num, kv_head_num, token_num, max_model_len


# TODO: add more tuned block sizes to the table
_ragged_attention_table = {
(32, 8, 4096, 2048): (128, 64),
(4, 1, 4096, 2048): (128, 128),
(32, 8, 2048, 2048): (128, 32),
(4, 1, 2048, 2048): (128, 64),
(32, 8, 1024, 2048): (64, 32),
(1, 1, 1024, 2048): (64, 32),
(32, 8, 4096, 4096): (128, 64),
(4, 1, 4096, 4096): (128, 128),
(32, 8, 2048, 4096): (128, 32),
(4, 1, 2048, 4096): (128, 64),
(32, 8, 1024, 4096): (64, 32),
(1, 1, 1024, 4096): (64, 32),
(32, 8, 4096, 64): (32, 32),
(4, 1, 4096, 64): (32, 32),
(32, 8, 2048, 64): (32, 32),
(4, 1, 2048, 64): (32, 32),
(32, 8, 1024, 64): (32, 32),
(1, 1, 1024, 64): (32, 32),
(32, 8, 4096, 128): (32, 32),
(4, 1, 4096, 128): (32, 32),
(32, 8, 2048, 128): (32, 32),
(4, 1, 2048, 128): (32, 32),
(32, 8, 1024, 128): (32, 32),
(1, 1, 1024, 128): (32, 32),
}


def get_ragged_attention_tuned_block_size(q_head_num, kv_head_num, token_num,
max_model_len):
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
if tpu_version == 4:
# This default block size is not tuned; it only makes sure there is no
# OOM in VMEM.
num_kv_pages_per_block = 16
num_queries_per_block = 128
return num_kv_pages_per_block, num_queries_per_block

key = _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num,
max_model_len)
block_sizes = _ragged_attention_table.get(key, (128, 32))
return block_sizes
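A small usage sketch of the lookup, using made-up request shapes (on TPU v4 the table is skipped entirely in favor of the conservative (16, 128) default):

from torch_xla.experimental.tuned_block_sizes import (
    _simplify_key_ragged_paged_attention, get_ragged_attention_tuned_block_size)

# A 3000-token batch with a 2048 max model length: token_num rounds up to 4096,
# so the key becomes (32, 8, 4096, 2048), which the table maps to (128, 64).
print(_simplify_key_ragged_paged_attention(32, 8, 3000, 2048))  # (32, 8, 4096, 2048)
print(get_ragged_attention_tuned_block_size(32, 8, 3000, 2048))  # (128, 64) on TPU v5 or newer

# Keys not present in the table fall back to the (128, 32) default.
print(get_ragged_attention_tuned_block_size(16, 2, 3000, 2048))  # (128, 32) on TPU v5 or newer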