Skip to content

Commit 9ffe905

Browse files
varun-sundar-rabindranath (Varun Sundar Rabindranath)
and authored
[Bugfix][Model] Fix LoRA for Mistral-Small-3.1-24B-Instruct-2503 (vllm-project#21183)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent 9a9fda1 commit 9ffe905

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

vllm/lora/models.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,14 @@ def remove_all_adapters(self):
498498
self._active_adapters.clear()
499499

500500
def _create_lora_modules(self):
501+
502+
def _parent_module(module_name: str) -> str:
503+
# module name is a dot separated name.
504+
# for example:
505+
# - given an input 'x.y.z' return 'x.y'
506+
# - given an input 'x' return ''
507+
return module_name.rpartition('.')[0]
508+
501509
for module_name, module in self.model.named_modules(
502510
remove_duplicate=False):
503511
if isinstance(module, PPMissingLayer):
@@ -529,10 +537,17 @@ def _create_lora_modules(self):
529537
new_module.scaling_factor_to_offset
530538
# (yard1): TODO make this more robust
531539
if "lm_head" in module_name:
540+
logits_processor_module_name = 'logits_processor'
541+
parent_module = _parent_module(module_name)
542+
if parent_module:
543+
logits_processor_module_name = (
544+
f"{parent_module}.{logits_processor_module_name}")
545+
532546
logits_processor_module = self.model.get_submodule(
533-
"logits_processor")
547+
logits_processor_module_name)
548+
534549
new_module = replace_submodule(
535-
self.model, "logits_processor",
550+
self.model, logits_processor_module_name,
536551
from_layer_logits_processor(logits_processor_module,
537552
module, self.lora_slots,
538553
self.lora_config,

vllm/lora/utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,20 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
188188
"""
189189
In vLLM, all linear layers support LoRA.
190190
"""
191+
191192
supported_lora_modules: set[str] = set()
192-
# step1: traverse the model to get all the linear subfixes.
193193
for name, module in model.named_modules():
194+
# get the embedding modules if the module's embedding_modules
195+
# is not empty.
196+
embedding_modules = getattr(module, "embedding_modules", None)
197+
if embedding_modules is not None:
198+
for name in embedding_modules:
199+
supported_lora_modules.add(name)
200+
201+
# get all the linear subfixes.
194202
if isinstance(module, (LinearBase, )):
195203
supported_lora_modules.add(name.split(".")[-1])
196-
# step 2: get the embedding modules if the model's mbedding_modules
197-
# is not empty.
198-
if model.embedding_modules:
199-
for name in model.embedding_modules:
200-
supported_lora_modules.add(name)
204+
201205
return list(supported_lora_modules)
202206

203207

0 commit comments

Comments (0)