12 changes: 5 additions & 7 deletions test/test_pallas.py
@@ -678,7 +678,6 @@ def ragged_paged_attention_wrapper(
sliding_window=sliding_window,
soft_cap=soft_cap,
use_kernel=True,
max_model_len=2048,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
):
@@ -693,7 +692,6 @@ def ragged_paged_attention_wrapper(
sliding_window=sliding_window,
soft_cap=soft_cap,
use_kernel=use_kernel,
max_model_len=max_model_len,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
)
@@ -714,7 +712,6 @@ def ragged_paged_attention_wrapper(
sliding_window=sliding_window,
soft_cap=soft_cap,
use_kernel=True,
max_model_len=2048,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
)[:cu_q_lens[num_seqs]]
@@ -755,14 +752,15 @@ def ragged_paged_attention_wrapper(

from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
max_model_len = 2048
if num_kv_pages_per_block is None:
assert num_queries_per_block is None
token_num = q.shape[0]
token_num, q_head_num, _ = q.shape
kv_head_num = kv_pages[2] // 2
_, page_size, num_combined_kv_heads, _ = kv_pages.shape
_, pages_per_seq = page_indices.shape
num_kv_heads = num_combined_kv_heads // 2
max_model_len = pages_per_seq * page_size
num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
q_head_num, kv_head_num, token_num, max_model_len)
q_head_num, num_kv_heads, token_num, max_model_len)
jax_kernel_output = torch.from_numpy(
np.array(
jax_ragged_paged_attention(
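For context, a minimal sketch of the bug fixed in this hunk, in plain PyTorch with placeholder shapes (not values from the test): the removed line indexed the kv_pages tensor instead of its shape, so kv_head_num ended up being a tensor rather than a head count.

import torch

# Hypothetical shapes: [num_pages, page_size, num_combined_kv_heads, head_dim]
kv_pages = torch.zeros(512, 16, 4, 128)

buggy = kv_pages[2] // 2  # a (16, 4, 128) tensor: one page, floor-divided elementwise
_, page_size, num_combined_kv_heads, _ = kv_pages.shape
num_kv_heads = num_combined_kv_heads // 2  # 2: K and V heads share the combined head axis

print(buggy.shape, num_kv_heads)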
13 changes: 6 additions & 7 deletions torch_xla/experimental/custom_kernel.py
@@ -930,7 +930,6 @@ def ragged_paged_attention(
soft_cap: float | None = None,
mask_value=None,
use_kernel=True,
max_model_len=2048, # Used as a hint for the kernel block sizes selection
# kernel tuning parameters
num_kv_pages_per_block=None,
num_queries_per_block=None,
@@ -960,9 +959,12 @@ def ragged_paged_attention(
if num_kv_pages_per_block is None:
assert num_queries_per_block is None
token_num, q_head_num, _ = q.shape
kv_head_num = kv_pages[2] // 2
_, page_size, num_combined_kv_heads, _ = kv_pages.shape
_, pages_per_seq = page_indices.shape
num_kv_heads = num_combined_kv_heads // 2
max_model_len = pages_per_seq * page_size
num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
q_head_num, kv_head_num, token_num, max_model_len)
q_head_num, num_kv_heads, token_num, max_model_len)

if vmem_limit_bytes is None:
vmem_limit_bytes = 64 * 1024 * 1024
@@ -1681,7 +1683,7 @@ def non_xla_ragged_paged_attention(q, kv, attention_type):
XLA_LIB.define(
"ragged_paged_attention(Tensor q, Tensor kv_pages, Tensor kv_lens, Tensor page_indices, "
"Tensor cu_q_lens, Tensor num_seqs, float sm_scale=1, int? sliding_window=None, "
"float? soft_cap=None, float? mask_value=None, bool use_kernel=True, int max_model_len=2048,"
"float? soft_cap=None, float? mask_value=None, bool use_kernel=True,"
"int? num_kv_pages_per_block=None, int? num_queries_per_block=None, int? vmem_limit_bytes=None) -> Tensor",
)

@@ -1699,7 +1701,6 @@ def ragged_paged_attention_xla(
soft_cap: float | None = None,
mask_value=None,
use_kernel=True,
max_model_len=2048,
# kernel tuning parameters
num_kv_pages_per_block=None,
num_queries_per_block=None,
@@ -1717,7 +1718,6 @@ def ragged_paged_attention_xla(
soft_cap=soft_cap,
mask_value=mask_value,
use_kernel=use_kernel,
max_model_len=max_model_len,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
vmem_limit_bytes=vmem_limit_bytes)
@@ -1736,7 +1736,6 @@ def ragged_paged_attention_non_xla(
soft_cap: float | None = None,
mask_value=None,
use_kernel=True,
max_model_len=2048,
# kernel tuning parameters
num_kv_pages_per_block=None,
num_queries_per_block=None,
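As a usage note, a hedged sketch of the block-size selection path after this change (requires torch_xla; the tensor shapes below are placeholders, not taken from this PR): max_model_len is no longer a caller-supplied hint but is derived as pages_per_seq * page_size from the paged KV layout.

import torch
from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

# Placeholder inputs: q is [token_num, q_head_num, head_dim], kv_pages is
# [num_pages, page_size, num_combined_kv_heads, head_dim], page_indices is
# [max_num_seqs, pages_per_seq].
q = torch.zeros(256, 8, 128)
kv_pages = torch.zeros(512, 16, 4, 128)
page_indices = torch.zeros(4, 64, dtype=torch.int32)

token_num, q_head_num, _ = q.shape
_, page_size, num_combined_kv_heads, _ = kv_pages.shape
_, pages_per_seq = page_indices.shape
num_kv_heads = num_combined_kv_heads // 2
max_model_len = pages_per_seq * page_size  # replaces the old max_model_len=2048 argument

num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
    q_head_num, num_kv_heads, token_num, max_model_len)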