# SPDX-License-Identifier: Apache-2.0

-from typing import Optional, Tuple, Union, final
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final

import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                    dispatch_bgmv_linear)

from .punica_base import PunicaWrapperBase
+from .utils import convert_mapping
+
+if TYPE_CHECKING:
+    # avoid circular import
+    from vllm.lora.layers import LoRAMapping
+    from vllm.lora.models import LongContextLoRAContext


@final
@@ -19,6 +25,55 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

+    def _update_base_metadata(
+        self,
+        mapping: "LoRAMapping",
+        lora_index_to_id: List[Optional[int]],
+        max_loras: int,
+        vocab_size: int,
+        extra_vocab_size: int,
+        long_lora_context: Optional["LongContextLoRAContext"] = None,
+    ):
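+        # Recompute the flat LoRA index tensors for the current batch and
+        # copy them into this wrapper's preallocated buffers.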
+        (
+            base_indices,
+            sampler_indices,
+            sampler_indices_padded,
+            embeddings_indices,
+            long_lora_offsets_tensor,
+            indices_len,
+        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
+                            extra_vocab_size, self.device, None)
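+        # NOTE: `None` is passed for the long-LoRA context above; the
+        # long-LoRA offsets tensor is instead rebuilt below in an
+        # HPU-friendly way.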
+        # Updating each element in `long_lora_offsets` with `lora_offset` slows
+        # down perf on HPU due to a series of `strided_insert` ops during lazy
+        # graph accumulation. Hence HPU appends each `lora_offset` to a Python
+        # list and converts it to a tensor only once the list is complete.
+        if long_lora_context:
+            index_mapping_indices: List[int] = list(
+                mapping.index_mapping).copy()
+            long_lora_offsets: List[int] = []
+            for i in range(len(index_mapping_indices)):
+                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
+                    index_mapping_indices[i], 0)
+                long_lora_offsets.append(lora_offset)
+            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
+                                                    device=self.device,
+                                                    dtype=torch.long)
+            indices_len[-1] = long_lora_offsets_tensor.shape[-1]
+
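+        # Copy the freshly computed indices into the preallocated buffers,
+        # writing only the leading slices that are valid for this batch.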
+        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
+        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
+        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
+            sampler_indices_padded)
+        self._embeddings_indices[:embeddings_indices.
+                                 shape[0], :embeddings_indices.shape[1]].copy_(
+                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self._long_lora_indices.zero_()
+        self.indices_len[:] = indices_len
+
    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
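
The comment inside `_update_base_metadata` explains why the long-LoRA offsets are gathered in a Python list and materialized as a tensor in one step instead of being written into a device tensor element by element. Below is a minimal sketch of the two patterns, using hypothetical stand-ins (`token_lora_ids`, `offsets_by_lora_id`) for the real mapping data; it is illustrative only, not code from this change.

import torch

# Hypothetical stand-ins for mapping.index_mapping and
# long_lora_context.offsets_by_lora_id.
token_lora_ids = [1, 1, 2, 0, 2]
offsets_by_lora_id = {1: 4096, 2: 8192}

# Pattern the change avoids: writing one element at a time into a tensor.
# Under HPU lazy-graph accumulation each assignment would become a separate
# strided-insert op.
offsets_slow = torch.zeros(len(token_lora_ids), dtype=torch.long)
for i, lora_id in enumerate(token_lora_ids):
    offsets_slow[i] = offsets_by_lora_id.get(lora_id, 0)

# Pattern the change uses: build a host-side list first, then materialize it
# with a single torch.tensor() call.
offsets_fast = torch.tensor(
    [offsets_by_lora_id.get(lora_id, 0) for lora_id in token_lora_ids],
    dtype=torch.long)

assert torch.equal(offsets_slow, offsets_fast)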