@@ -74,7 +74,6 @@ def wrapper(*args, **kwargs):
             moe_state_dict["apply_router_weight_on_input"] = kwargs[
                 "apply_router_weight_on_input"
             ]
-            moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
             result = func(*args, **kwargs)
             return result

@@ -89,7 +88,6 @@ def wrapper(*args, **kwargs):
             curr_topk_ids = moe_state_dict["topk_ids"]
             global_num_experts = moe_state_dict["global_num_experts"]
             expert_map = moe_state_dict["expert_map"]
-            max_loras = moe_state_dict["max_loras"]

             config_dtype = _get_config_dtype_str(
                 dtype=hidden_states.dtype,
@@ -110,6 +108,7 @@ def wrapper(*args, **kwargs):
                 block_shape=layer.quant_method.moe_quant_config.block_shape,
             )

+            max_loras = self.w1_lora_a_stacked.shape[0]
             config = get_config_func(M)
             (
                 sorted_token_ids_lora,
@@ -161,7 +160,6 @@ def moe_sum_decorator(layer, func):
         def wrapper(*args, **kwargs):
             hidden_states = moe_state_dict["hidden_states"]
             topk_weights = moe_state_dict["topk_weights"]
-            max_loras = moe_state_dict["max_loras"]

             config_dtype = _get_config_dtype_str(
                 dtype=hidden_states.dtype,
@@ -189,7 +187,7 @@ def wrapper(*args, **kwargs):
             num_tokens_post_padded_lora = moe_state_dict[
                 "num_tokens_post_padded_lora"
             ]
-
+            max_loras = self.w1_lora_a_stacked.shape[0]
             expert_ids_lora = expert_ids_lora.view(max_loras, -1)
             sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
             intermediate_cache2 = moe_state_dict["intermediate_cache2"]
@@ -305,12 +303,6 @@ def create_lora_weights(
             device=self.device,
         )

-        self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
-        self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
-        self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
-        self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
-        self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
-        self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
         # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
         # to create a dummy LoRA weights.
         self.lora_a_stacked = []
@@ -343,6 +335,7 @@ def set_lora(
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
+        self.reset_lora(index)
         """Overwrites lora tensors at index."""
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
@@ -352,6 +345,10 @@ def set_lora(
             w2_lora_b = lora_b[eid * 3 + 1]
             w3_lora_b = lora_b[eid * 3 + 2]

+            # Handle the case of adding LoRA to only a subset of experts
+            if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
+                continue
+
            if self.tp_size > 1:
                shard_size = self.base_layer.intermediate_size_per_partition
                start_idx = self.tp_rank * shard_size
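
The recurring change in the hunks above stops stashing a "max_loras" entry in moe_state_dict and instead derives the value from the first dimension of the layer's stacked LoRA tensors. A minimal sketch of that derivation follows; the tensor shape used here is made up for illustration and is not taken from the actual layer configuration:

    import torch

    # Hypothetical stacked LoRA-A weights for the w1 projection, laid out as
    # (max_loras, num_experts, rank, hidden_size). Only the first dimension
    # matters for this sketch.
    w1_lora_a_stacked = torch.zeros(4, 8, 16, 1024)

    # Equivalent of `max_loras = self.w1_lora_a_stacked.shape[0]` in the diff:
    # the LoRA slot count is read straight off the tensor, so no separate
    # "max_loras" entry has to be kept in moe_state_dict.
    max_loras = w1_lora_a_stacked.shape[0]
    assert max_loras == 4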