
Commit 25dd63d

Handle long-context + lora explicitly after convert_mapping
Signed-off-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
1 parent 5cd58d1 commit 25dd63d

File tree

vllm/lora/punica_wrapper/punica_hpu.py
vllm/lora/punica_wrapper/utils.py

2 files changed: +61 -22 lines changed


vllm/lora/punica_wrapper/punica_hpu.py

Lines changed: 56 additions & 1 deletion
@@ -1,12 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Optional, Tuple, Union, final
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
 
 import torch
 from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                     dispatch_bgmv_linear)
 
 from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext
 
 
 @final
@@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
         PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                    max_batches, device)
 
+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf in HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends `lora_offset` to a list and
+        # converts it to a tensor only after it is ready.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
     def add_lora_embedding(self,
                            y: torch.Tensor,
                            x: torch.Tensor,
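
The comment carried into `_update_base_metadata` above is the core of the change: under HPU lazy-mode graph accumulation, writing offsets into a pre-allocated tensor one element at a time is traced as a chain of `strided_insert` ops, while collecting the offsets in a Python list and materializing them with a single `torch.tensor(...)` call is not. Below is a minimal sketch of the two patterns, not taken from the commit; the helper names and the toy `index_mapping` / `offsets_by_lora_id` data are made up for illustration, and plain CPU tensors are used just to show the shape of the code.

from typing import Dict, List

import torch


def offsets_per_element(index_mapping: List[int],
                        offsets_by_lora_id: Dict[int, int]) -> torch.Tensor:
    # Pre-allocate, then write one element at a time. On HPU in lazy mode
    # each assignment is traced as a separate `strided_insert` op.
    out = torch.zeros(len(index_mapping), dtype=torch.long)
    for i, lora_id in enumerate(index_mapping):
        out[i] = offsets_by_lora_id.get(lora_id, 0)
    return out


def offsets_via_list(index_mapping: List[int],
                     offsets_by_lora_id: Dict[int, int]) -> torch.Tensor:
    # Collect offsets on the host first, then materialize the tensor once,
    # mirroring what the HPU override above does.
    offsets = [offsets_by_lora_id.get(lora_id, 0) for lora_id in index_mapping]
    return torch.tensor(offsets, dtype=torch.long)


if __name__ == "__main__":
    index_mapping = [1, 1, 2, 0]          # toy token -> LoRA id mapping
    offsets_by_lora_id = {1: 16, 2: 32}   # toy long-context offsets
    assert torch.equal(offsets_per_element(index_mapping, offsets_by_lora_id),
                       offsets_via_list(index_mapping, offsets_by_lora_id))

Both helpers return the same values; only the number of ops recorded during lazy graph accumulation differs, which is what the HPU-specific override avoids.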

vllm/lora/punica_wrapper/utils.py

Lines changed: 5 additions & 21 deletions
@@ -88,18 +88,10 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
-
-    # Updating each element in `long_lora_offsets` with `lora_offset` slows
-    # down perf in HPU due to a series of `strided_insert` ops during lazy
-    # graph accumulation. Hence HPU appends `lora_offset` to a list and
-    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        if device == torch.device("hpu"):
-            long_lora_offsets_list: List[int] = []
-        else:
-            long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                            device=device,
-                                            dtype=torch.long)
+        long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                        device=device,
+                                        dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -112,18 +104,10 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
+            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            if device == torch.device("hpu"):
-                long_lora_offsets_list.append(lora_offset)
-            else:
-                assert long_lora_offsets is not None
-                long_lora_offsets[i] = lora_offset
-
-    if long_lora_context and device == torch.device("hpu"):
-        long_lora_offsets = torch.tensor(long_lora_offsets_list,
-                                         device=device,
-                                         dtype=torch.long)
+            long_lora_offsets[i] = lora_offset
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
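
Taken together, the two changes move the HPU special case out of the shared helper: `convert_mapping` is device-agnostic again (pre-allocate with `torch.zeros`, then assign per index), while the HPU wrapper calls it with `None` as `long_lora_context` (the last argument in the call above), so the generic long-LoRA path never runs on HPU; the offsets tensor is instead rebuilt in `_update_base_metadata` from a host-side list.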

0 commit comments
