40 changes: 40 additions & 0 deletions vllm/attention/ops/rocm_aiter_mla.py
@@ -33,6 +33,14 @@ def aiter_mla_decode_fwd(
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
@@ -45,6 +53,14 @@ def aiter_mla_decode_fwd(
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
work_info_set=work_info_set,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
q_scale=q_scale,
kv_scale=kv_scale,
)


@@ -59,6 +75,14 @@ def mla_decode_fwd_impl(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None,
) -> None:
from aiter.mla import mla_decode_fwd

@@ -73,6 +97,14 @@ def mla_decode_fwd_impl(
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
work_info_set=work_info_set,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
q_scale=q_scale,
kv_scale=kv_scale,
)


@@ -87,6 +119,14 @@ def mla_decode_fwd_fake(
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None,
) -> None:
pass

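A minimal call sketch for the extended wrapper. Every shape, size, and value below is invented for illustration (the real ones come from the metadata builder in the second file), the scheduling/reduction buffers would normally be populated by aiter.get_mla_metadata_v1 first, and it assumes a ROCm build of vLLM with aiter available:

```python
# Illustrative sketch only; all sizes are made-up examples, and the work/reduce
# buffers are passed here uninitialized purely to show the new keyword arguments.
import torch

from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd

device = "cuda"
num_cu = torch.cuda.get_device_properties(device).multi_processor_count

batch = 4                # decode requests, one new token each (hypothetical)
num_heads = 16           # hypothetical per-rank head count
kv_lora_rank = 512
qk_rope_head_dim = 64
max_tiles = batch        # stands in for max_num_reqs * max_qo_tiles_per_batch

q = torch.randn(batch, num_heads, kv_lora_rank + qk_rope_head_dim,
                dtype=torch.bfloat16, device=device)
# Paged KV cache with block size 1: [num_pages, page_size, 1, head_dim]
kv_buffer = torch.randn(batch, 1, 1, kv_lora_rank + qk_rope_head_dim,
                        dtype=torch.bfloat16, device=device)
o = torch.zeros(batch, num_heads, kv_lora_rank, dtype=torch.bfloat16, device=device)

qo_indptr = torch.arange(batch + 1, dtype=torch.int32, device=device)
kv_indptr = torch.arange(batch + 1, dtype=torch.int32, device=device)
kv_indices = torch.arange(batch, dtype=torch.int32, device=device)
kv_last_page_lens = torch.ones(batch, dtype=torch.int32, device=device)

# New persistent scheduling / reduction buffers (shapes mirror the builder below).
work_meta_data = torch.empty(10, dtype=torch.uint64, device=device)
work_indptr = torch.empty(num_cu + 1, dtype=torch.int32, device=device)
work_info_set = torch.full((max_tiles * num_cu, 8), -1, dtype=torch.int32, device=device)
reduce_indptr = torch.empty(max_tiles + 1, dtype=torch.int32, device=device)
reduce_final_map = torch.empty((max_tiles, 2), dtype=torch.int32, device=device)
reduce_partial_map = torch.empty(max_tiles * num_cu, dtype=torch.int32, device=device)

aiter_mla_decode_fwd(
    q, kv_buffer, o, 0.1,      # 0.1 stands in for the real softmax scale
    qo_indptr, 1,              # max_seqlen_qo = 1 for plain (non-MTP) decode
    kv_indptr, kv_indices, kv_last_page_lens,
    work_meta_data=work_meta_data,
    work_indptr=work_indptr,
    work_info_set=work_info_set,
    reduce_indptr=reduce_indptr,
    reduce_final_map=reduce_final_map,
    reduce_partial_map=reduce_partial_map,
)
```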
103 changes: 96 additions & 7 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from dataclasses import dataclass
from typing import ClassVar

@@ -17,6 +18,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -56,6 +58,20 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The query indptr, shape : [num_decode + 1]
qo_indptr: torch.Tensor | None = None

max_seqlen_qo: int = 1

# Persistent work-scheduling and split-KV reduction buffers, allocated once by
# the metadata builder and populated by aiter.get_mla_metadata_v1 in _build_decode.
work_metadata: torch.Tensor | None = None

work_info_set: torch.Tensor | None = None

work_indptr: torch.Tensor | None = None

reduce_indptr: torch.Tensor | None = None

reduce_final_map: torch.Tensor | None = None

reduce_partial_map: torch.Tensor | None = None


class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass
@@ -64,9 +80,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
-cudagraph_support: ClassVar[AttentionCGSupport] = (
-    AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-)
+cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN

def __init__(
self,
@@ -82,13 +97,47 @@ def __init__(
"AITER MLAonly supports block size 1."
)

gpu = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(gpu)
cu_num = device_properties.multi_processor_count

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(
vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req

# num_mtp = vllm_config.speculative_config.num_speculative_tokens
# num_mtp = 1 if num_mtp is None else num_mtp
max_seqlen_qo = (
1
if vllm_config.speculative_config is None
else vllm_config.speculative_config.num_speculative_tokens
)

max_qo_tiles_per_batch = int(math.ceil(max_seqlen_qo * self.num_heads / 128))
self.work_metadata = torch.empty([10], dtype=torch.uint64, device="cuda")
self.work_indptr = torch.empty([cu_num + 1], dtype=torch.int32, device="cuda")
self.work_info_set = torch.empty(
[max_num_reqs * max_qo_tiles_per_batch * cu_num, 8],
dtype=torch.int32,
device="cuda",
).fill_(-1)
self.reduce_indptr = torch.empty(
[max_num_reqs * max_qo_tiles_per_batch + 1],
dtype=torch.int32,
device="cuda",
)
self.reduce_final_map = torch.empty(
[max_num_reqs * max_qo_tiles_per_batch, 2], dtype=torch.int32, device="cuda"
)
self.reduce_partial_map = torch.empty(
[max_num_reqs * max_qo_tiles_per_batch * cu_num],
dtype=torch.int32,
device="cuda",
)

# Preparing persistent buffers
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
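For a rough sense of the footprint of these persistent buffers, a small arithmetic sketch with invented config values (the CU count, head count, and request count are illustrative assumptions, not taken from a real deployment):

```python
# Hypothetical sizing of the persistent scheduling buffers (illustrative numbers).
import math

cu_num = 80          # multi_processor_count reported for the GPU
max_num_reqs = 256   # scheduler_config.max_num_seqs
num_heads = 16
max_seqlen_qo = 2    # speculative_config.num_speculative_tokens, else 1

max_qo_tiles_per_batch = math.ceil(max_seqlen_qo * num_heads / 128)      # -> 1
work_info_set_rows = max_num_reqs * max_qo_tiles_per_batch * cu_num      # -> 20480
reduce_partial_map_len = max_num_reqs * max_qo_tiles_per_batch * cu_num  # -> 20480
print(max_qo_tiles_per_batch, work_info_set_rows, reduce_partial_map_len)
```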
@@ -139,6 +188,32 @@ def _build_decode(
block_table_bounds.cumsum(dim=0, dtype=torch.int32),
]
)
kv_indptr = torch.zeros(
[query_start_loc_cpu.size(0)], dtype=torch.int32, device="cuda"
)
torch.cumsum(seq_lens_device, dim=0, out=kv_indptr[1:])
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_seqlen_qo = torch.max(query_lens).item()

import aiter

aiter.get_mla_metadata_v1(
query_start_loc_device,
kv_indptr,
self.num_heads // self.kv_cache_spec.num_kv_heads,
self.kv_cache_spec.num_kv_heads,
True,
self.work_metadata,
self.work_info_set,
self.work_indptr,
self.reduce_indptr,
self.reduce_final_map,
self.reduce_partial_map,
kv_granularity=max(self.kv_cache_spec.block_size, 16),
max_seqlen_qo=max_seqlen_qo,
uni_seqlen_qo=max_seqlen_qo,
fast_mode=True,
)

if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0)
@@ -176,6 +251,13 @@ def _build_decode(
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_seqlen_qo=max_seqlen_qo,
work_metadata=self.work_metadata,
work_info_set=self.work_info_set,
work_indptr=self.work_indptr,
reduce_indptr=self.reduce_indptr,
reduce_final_map=self.reduce_final_map,
reduce_partial_map=self.reduce_partial_map,
)

return attn_metadata
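To make the indptr bookkeeping above concrete, a tiny CPU-only example with invented lengths (the tensor values are hypothetical; only the derivation mirrors the code):

```python
# CPU-only illustration of the indptr / max_seqlen_qo derivation (invented lengths).
import torch

query_start_loc = torch.tensor([0, 1, 3, 4], dtype=torch.int32)  # 3 decode requests
seq_lens = torch.tensor([17, 40, 8], dtype=torch.int32)          # cached KV lengths

kv_indptr = torch.zeros(query_start_loc.size(0), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)                    # [0, 17, 57, 65]

query_lens = query_start_loc[1:] - query_start_loc[:-1]          # [1, 2, 1]
max_seqlen_qo = int(query_lens.max())                            # 2 (e.g. MTP decode)
print(kv_indptr.tolist(), query_lens.tolist(), max_seqlen_qo)
```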
@@ -256,24 +338,31 @@ def _forward_decode(
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(
-    B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
-)
+    B, self.num_heads, self.kv_lora_rank, dtype=torch.bfloat16, device=q.device
+).fill_(-1)

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

-# max_seqlen_qo must be 1 except for MTP
-# TODO: Find the best value for MTP
-max_seqlen_qo = 1
aiter_mla_decode_fwd(
q,
kv_buffer,
o,
self.scale,
attn_metadata.decode.qo_indptr,
-max_seqlen_qo,
+attn_metadata.decode.max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len,
work_meta_data=attn_metadata.decode.work_metadata,
work_indptr=attn_metadata.decode.work_indptr,
work_info_set=attn_metadata.decode.work_info_set,
reduce_indptr=attn_metadata.decode.reduce_indptr,
reduce_final_map=attn_metadata.decode.reduce_final_map,
reduce_partial_map=attn_metadata.decode.reduce_partial_map,
q_scale=layer._q_scale,
kv_scale=layer._k_scale,
)

return o, None