@@ -88,10 +88,18 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
+    # Updating each element in `long_lora_offsets` with `lora_offset` slows
+    # down perf in HPU due to a series of `strided_insert` ops during lazy
+    # graph accumulation. Hence HPU appends `lora_offset` to a list and
+    # converts it to a tensor only after it is ready.
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=device,
-                                        dtype=torch.long)
+        if device == torch.device("hpu"):
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=device,
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -104,10 +112,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if device == torch.device("hpu"):
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and device == torch.device("hpu"):
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=device,
+                                         dtype=torch.long)

     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
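For context, a minimal standalone sketch of the pattern this diff adopts (function names here are illustrative, not code from the PR): on a lazily traced backend such as HPU, writing tensor elements one at a time adds one strided-insert node per write to the accumulated graph, whereas collecting the values in a Python list and calling `torch.tensor` once materializes the tensor in a single op.

```python
from typing import List

import torch


def offsets_per_element_write(values: List[int],
                              device: torch.device) -> torch.Tensor:
    # One tensor write per iteration; under lazy graph accumulation each
    # write becomes a separate strided-insert op.
    out = torch.zeros(len(values), device=device, dtype=torch.long)
    for i, v in enumerate(values):
        out[i] = v
    return out


def offsets_single_conversion(values: List[int],
                              device: torch.device) -> torch.Tensor:
    # Accumulate on the host first, then materialize the tensor in one call;
    # this mirrors the HPU path introduced by the diff.
    collected: List[int] = list(values)
    return torch.tensor(collected, device=device, dtype=torch.long)


if __name__ == "__main__":
    vals = [0, 4096, 0, 8192]
    cpu = torch.device("cpu")
    assert torch.equal(offsets_per_element_write(vals, cpu),
                       offsets_single_conversion(vals, cpu))
```

Both helpers produce the same tensor; the difference only matters for how many ops the lazy backend has to record before execution.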