Commit 5cd58d1

Enable long-contexts + LoRA support for Intel Gaudi
Signed-off-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
1 parent 6e1fc61 · commit 5cd58d1

File tree: 3 files changed, 38 insertions(+), 8 deletions(-)

vllm/lora/punica_wrapper/utils.py

Lines changed: 21 additions & 5 deletions
@@ -88,10 +88,18 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
+    # Updating each element in `long_lora_offsets` with `lora_offset` slows
+    # down perf in HPU due to a series of `strided_insert` ops during lazy
+    # graph accumulation. Hence HPU appends `lora_offset` to a list and
+    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+        if device == torch.device("hpu"):
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=device,
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,10 +112,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if device == torch.device("hpu"):
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and device == torch.device("hpu"):
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=device,
+                                         dtype=torch.long)

     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
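
For context on the comment above: under HPU lazy mode, each elementwise tensor write is recorded as a separate strided_insert op in the accumulated graph, so the graph grows with token count, whereas appending to a Python list and materializing one tensor at the end records a single op. Below is a minimal CPU-only sketch contrasting the two patterns; the sample values are illustrative, not code from the commit:

import torch

offsets_by_lora_id = {1: 4096, 2: 8192}   # illustrative LoRA-id -> offset map
index_mapping_indices = [1, 1, 2, 0, 2]   # illustrative token -> LoRA ids

# Per-element writes into a preallocated tensor: on HPU lazy mode, every
# `tensor[i] = x` would append a strided_insert op to the graph.
long_lora_offsets = torch.zeros(len(index_mapping_indices), dtype=torch.long)
for i, idx in enumerate(index_mapping_indices):
    long_lora_offsets[i] = offsets_by_lora_id.get(idx, 0)

# The pattern this commit uses on HPU: accumulate plain Python ints and
# build the tensor once, when the list is complete.
offsets_list = [offsets_by_lora_id.get(idx, 0) for idx in index_mapping_indices]
long_lora_offsets_hpu = torch.tensor(offsets_list, dtype=torch.long)

# Both paths yield identical values; only the number of recorded ops differs.
assert torch.equal(long_lora_offsets, long_lora_offsets_hpu)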

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -206,9 +206,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
             num_tokens, 1, -1)
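
For context: the reordering applies the long-LoRA offsets while positions are still batched, and flattens only afterwards, so index_select still receives a 1-D index tensor. A minimal shape sketch, assuming positions arrive as (batch, seq_len) and offsets arrive flat with one entry per token (an assumption for illustration, not code from the commit):

import torch

batch, seq_len = 2, 4
positions = torch.arange(batch * seq_len, dtype=torch.long).view(batch, seq_len)
# Flat per-token offsets: sequence 1 unshifted, sequence 2 shifted by 4096.
offsets = torch.tensor([0] * seq_len + [4096] * seq_len, dtype=torch.long)

# Reshape the flat offsets to line up with the batched positions before
# adding, then flatten once for the downstream cos/sin cache lookup.
offsets = offsets.view(positions.shape[0], -1)   # (2, 4)
positions = (positions + offsets).flatten()      # (8,)
print(positions)  # tensor([   0,    1,    2,    3, 4100, 4101, 4102, 4103])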

vllm/worker/hpu_model_runner.py

Lines changed: 15 additions & 2 deletions
@@ -639,12 +639,25 @@ def load_model(self) -> None:
                 "Bias support in LoRA is not enabled in HPU yet."
             assert not self.lora_config.fully_sharded_loras, \
                 "Fully sharded LoRAs is not enabled in HPU yet."
+            # It's necessary to distinguish between the
+            # max_position_embeddings of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = (
+                    self.model.config.max_position_embeddings)
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size, self.lora_config, self.device,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
                 self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)

             if self.model_config.quantization == 'inc':
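
For context: Hugging Face VLM configs typically nest the language model's settings under text_config, which is why the top-level attribute can be missing. A standalone sketch of the same fallback; the helper name and model ids are hypothetical examples, not part of the commit:

from transformers import AutoConfig

def get_max_pos_embeddings(model_name: str) -> int:
    """Return max_position_embeddings for either an LLM or a VLM config."""
    config = AutoConfig.from_pretrained(model_name)
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings          # plain LLM config
    return config.text_config.max_position_embeddings  # VLM: nested text config

# Hypothetical usage (model ids are examples only):
# get_max_pos_embeddings("meta-llama/Llama-3.1-8B")   # top-level attribute
# get_max_pos_embeddings("llava-hf/llava-1.5-7b-hf")  # nested under text_config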
