
Commit abf3db4

[Core] Handle MoE LoRA edge cases (vllm-project#27335)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 8e4ca4d commit abf3db4

2 files changed: 7 additions & 11 deletions

vllm/lora/layers/fused_moe.py

Lines changed: 7 additions & 10 deletions
@@ -74,7 +74,6 @@ def wrapper(*args, **kwargs):
                 moe_state_dict["apply_router_weight_on_input"] = kwargs[
                     "apply_router_weight_on_input"
                 ]
-                moe_state_dict["max_loras"] = layer.w1_lora_a_stacked.shape[0]
                 result = func(*args, **kwargs)
                 return result

@@ -89,7 +88,6 @@ def wrapper(*args, **kwargs):
                 curr_topk_ids = moe_state_dict["topk_ids"]
                 global_num_experts = moe_state_dict["global_num_experts"]
                 expert_map = moe_state_dict["expert_map"]
-                max_loras = moe_state_dict["max_loras"]

                 config_dtype = _get_config_dtype_str(
                     dtype=hidden_states.dtype,
@@ -110,6 +108,7 @@ def wrapper(*args, **kwargs):
                     block_shape=layer.quant_method.moe_quant_config.block_shape,
                 )

+                max_loras = self.w1_lora_a_stacked.shape[0]
                 config = get_config_func(M)
                 (
                     sorted_token_ids_lora,
@@ -161,7 +160,6 @@ def moe_sum_decorator(layer, func):
             def wrapper(*args, **kwargs):
                 hidden_states = moe_state_dict["hidden_states"]
                 topk_weights = moe_state_dict["topk_weights"]
-                max_loras = moe_state_dict["max_loras"]

                 config_dtype = _get_config_dtype_str(
                     dtype=hidden_states.dtype,
@@ -189,7 +187,7 @@ def wrapper(*args, **kwargs):
                 num_tokens_post_padded_lora = moe_state_dict[
                     "num_tokens_post_padded_lora"
                 ]
-
+                max_loras = self.w1_lora_a_stacked.shape[0]
                 expert_ids_lora = expert_ids_lora.view(max_loras, -1)
                 sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
                 intermediate_cache2 = moe_state_dict["intermediate_cache2"]
@@ -305,12 +303,6 @@ def create_lora_weights(
             device=self.device,
         )

-        self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
-        self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
-        self.base_layer.w2_lora_a_stacked = self.w2_lora_a_stacked
-        self.base_layer.w2_lora_b_stacked = self.w2_lora_b_stacked
-        self.base_layer.w3_lora_a_stacked = self.w3_lora_a_stacked
-        self.base_layer.w3_lora_b_stacked = self.w3_lora_b_stacked
         # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
         # to create a dummy LoRA weights.
         self.lora_a_stacked = []
@@ -343,6 +335,7 @@ def set_lora(
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
+        self.reset_lora(index)
         """Overwrites lora tensors at index."""
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
@@ -352,6 +345,10 @@ def set_lora(
             w2_lora_b = lora_b[eid * 3 + 1]
             w3_lora_b = lora_b[eid * 3 + 2]

+            # Handle the case of adding LoRA to only a subset of experts
+            if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
+                continue
+
             if self.tp_size > 1:
                 shard_size = self.base_layer.intermediate_size_per_partition
                 start_idx = self.tp_rank * shard_size
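
Taken together, the fused_moe.py changes do two things: max_loras is now read directly from self.w1_lora_a_stacked.shape[0] inside the kernel wrappers instead of being stashed in moe_state_dict (which also lets the six base_layer.*_lora_*_stacked mirrors be dropped from create_lora_weights), and set_lora now resets the slot first and skips experts whose LoRA tensors are None. The skip is the edge case named in the commit title: an adapter may target only a subset of experts. Below is a minimal, self-contained sketch of that skip logic; the helper name, shapes, and example data are illustrative and not vLLM's API.

import torch

# Hypothetical stand-in for the per-expert loop in FusedMoEWithLoRA.set_lora():
# lora_a is a flat list holding three chunks (w1, w2, w3) per expert, and
# experts that the adapter does not target are represented by None entries.
def collect_targeted_experts(lora_a: list) -> list:
    targeted = []
    for eid in range(len(lora_a) // 3):
        w1_lora_a = lora_a[eid * 3]
        w2_lora_a = lora_a[eid * 3 + 1]
        w3_lora_a = lora_a[eid * 3 + 2]
        # Edge case handled by this commit: skip experts with no LoRA weights.
        if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
            continue
        targeted.append(eid)
    return targeted

# Example: a 3-expert layer where only experts 0 and 2 carry LoRA weights
# (rank and hidden size are arbitrary illustrative values).
rank, hidden = 8, 16
lora_a = (
    [torch.randn(rank, hidden)] * 3   # expert 0
    + [None] * 3                      # expert 1, untouched by the adapter
    + [torch.randn(rank, hidden)] * 3 # expert 2
)
assert collect_targeted_experts(lora_a) == [0, 2]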

vllm/lora/models.py

Lines changed: 0 additions & 1 deletion
@@ -426,7 +426,6 @@ def activate_adapter(
         for module_name, module in self.modules.items():
             module_lora = self._get_lora_layer_weights(lora_model, module_name)
             if module_lora:
-                module_lora.optimize()
                 # Note (gnovack) - If MOE lora weights are not split into
                 # num_experts chunks, we split them here
                 if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
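
The only change here drops the module_lora.optimize() call during adapter activation. For context, the comment retained in this hunk describes splitting monolithic MoE LoRA weights into num_experts chunks; a rough sketch of what such a split can look like follows, with the helper name and tensor layout assumed for illustration rather than taken from this commit.

import torch

# Illustrative only: split a single stacked MoE LoRA tensor into one chunk per
# expert, as described by the "split them here" comment above. The actual
# splitting code in LoRAModelManager.activate_adapter is not shown in this diff.
def split_moe_lora(stacked: torch.Tensor, num_experts: int) -> list:
    # (num_experts * rank, hidden) -> num_experts tensors of shape (rank, hidden)
    return list(torch.chunk(stacked, num_experts, dim=0))

stacked = torch.randn(4 * 8, 16)   # e.g. 4 experts, rank 8, hidden size 16
chunks = split_moe_lora(stacked, num_experts=4)
assert len(chunks) == 4 and chunks[0].shape == (8, 16)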
