
Commit 2821986

[Core] Modify the initialization parameters of the lora manager (vllm-project#25249)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 6c117cf commit 2821986

File tree: 10 files changed, +51 -52 lines
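
In short: `WorkerLoRAManager` (and `LRUCacheWorkerLoRAManager`) is now constructed from a single `VllmConfig` rather than from individually unpacked scheduler/model/LoRA values, and the `vllm.lora.lora` module moves to `vllm.lora.lora_weights`. A condensed before/after sketch of the call-site change, using the variable names from the diffs below:

```python
# Before: callers unpacked each value from its own config object.
lora_manager = LRUCacheWorkerLoRAManager(
    scheduler_config.max_num_seqs,
    scheduler_config.max_num_batched_tokens,
    model_config.get_vocab_size(),
    lora_config,
    device,
    model.embedding_modules,
    model.embedding_padding_modules,
    max_position_embeddings=text_config.max_position_embeddings,
)

# After: the manager takes the whole VllmConfig and derives those
# values itself inside __init__.
lora_manager = LRUCacheWorkerLoRAManager(
    vllm_config,
    device,
    model.embedding_modules,
    model.embedding_padding_modules,
)
```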

tests/lora/test_lora_manager.py

Lines changed: 28 additions & 8 deletions
```diff
@@ -8,11 +8,12 @@
 from safetensors.torch import load_file
 from torch import nn
 
+from vllm.config import ModelConfig, VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA)
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager)
 from vllm.lora.peft_helper import PEFTHelper
@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
         target_modules=["layer1.dense1", "dense2"],
         lora_dtype=DEFAULT_DTYPE,
     )
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
-        4, 2,
-        dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
-        lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+
+    worker_adapter_manager.max_num_seqs = 4
+    worker_adapter_manager.max_num_batched_tokens = 2
+
     worker_adapter_manager.create_lora_manager(dummy_model)
 
     mapping = LoRAMapping([], [])
@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
         max_cpu_loras=4,
         max_loras=4,
         lora_dtype=DEFAULT_DTYPE)
-    worker_adapter_manager = WorkerLoRAManager(
-        4, 2, dummy_model_gate_up.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, device,
-        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
+
+    worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
+                                               EMBEDDING_MODULES,
+                                               EMBEDDING_PADDING_MODULES)
+    worker_adapter_manager.vocab_size = (
+        dummy_model_gate_up.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size)
     worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
 
     dummy_lora_files = f"{tmp_path}/lora_adapter"
```

tests/lora/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@
 import torch
 from safetensors.torch import save_file
 
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 
 
 class DummyLoRAManager:
```
vllm/lora/lora.py → vllm/lora/lora_weights.py

File renamed without changes.
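
The module rename is mechanical; any out-of-tree code importing the old path would need the same one-line change seen in the diffs above and below, for example:

```python
# Old import path (removed by the rename):
# from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

# New import path:
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
```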

vllm/lora/models.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
```

vllm/lora/worker_manager.py

Lines changed: 12 additions & 11 deletions
```diff
@@ -6,7 +6,7 @@
 
 import torch
 
-from vllm.config.lora import LoRAConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
@@ -27,25 +27,26 @@ class WorkerLoRAManager:
 
     def __init__(
        self,
-        max_num_seqs: int,
-        max_num_batched_tokens: int,
-        vocab_size: int,
-        lora_config: LoRAConfig,
+        vllm_config: VllmConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        embedding_padding_modules: list[str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
-        max_position_embeddings: Optional[int] = None,
     ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
-        self.max_num_seqs = max_num_seqs
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.vocab_size = vocab_size
-        self.lora_config = lora_config
-        self.max_position_embeddings = max_position_embeddings
+        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        self.max_num_batched_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.vocab_size = vllm_config.model_config.get_vocab_size()
+        self.lora_config = vllm_config.lora_config
+
+        # Use get_text_config() in case of multimodal models
+        text_config = vllm_config.model_config.hf_config.get_text_config()
+
+        self.max_position_embeddings = text_config.max_position_embeddings
        self.device = device
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager
```
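
As a quick reference, a minimal usage sketch of the new signature, mirroring the updated tests above; `lora_config`, `device`, `model`, and the `EMBEDDING_MODULES`/`EMBEDDING_PADDING_MODULES` constants are assumed to be in scope as in the test file, and the scheduler values are illustrative:

```python
from vllm.config import ModelConfig, VllmConfig
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager

# Bundle the sub-configs into the single object the manager now takes.
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2

# __init__ now reads max_num_seqs, max_num_batched_tokens, vocab_size,
# lora_config, and max_position_embeddings out of vllm_config itself.
lora_manager = LRUCacheWorkerLoRAManager(vllm_config, device,
                                         EMBEDDING_MODULES,
                                         EMBEDDING_PADDING_MODULES)
lora_manager.create_lora_manager(model)
```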

vllm/v1/worker/cpu_model_runner.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -107,9 +107,8 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
         if self.lora_config:
-            self.model = self.load_lora_model(self.model, self.model_config,
-                                              self.scheduler_config,
-                                              self.lora_config, self.device)
+            self.model = self.load_lora_model(self.model, self.vllm_config,
+                                              self.device)
 
     def get_model(self) -> nn.Module:
         return self.model
```

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -2552,10 +2552,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         self.model = model_loader.load_model(
             vllm_config=self.vllm_config, model_config=self.model_config)
         if self.lora_config:
-            self.model = self.load_lora_model(self.model,
-                                              self.model_config,
-                                              self.scheduler_config,
-                                              self.lora_config,
+            self.model = self.load_lora_model(self.model, self.vllm_config,
                                               self.device)
         if hasattr(self, "drafter"):
             logger.info("Loading drafter model...")
```

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 3 additions & 12 deletions
```diff
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import VllmConfig
 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -31,9 +31,7 @@ class LoRAModelRunnerMixin:
 
     LORA_WARMUP_RANK = 8
 
-    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
-                        scheduler_config: SchedulerConfig,
-                        lora_config: LoRAConfig,
+    def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig,
                         device: torch.device) -> nn.Module:
 
         if not supports_lora(model):
@@ -44,19 +42,12 @@ def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
             logger.warning("Regarding multimodal models, vLLM currently "
                            "only supports adding LoRA to language model.")
 
-        # Use get_text_config() in case of multimodal models
-        text_config = model_config.hf_config.get_text_config()
-
         # Add LoRA Manager to the Model Runner
         self.lora_manager = LRUCacheWorkerLoRAManager(
-            scheduler_config.max_num_seqs,
-            scheduler_config.max_num_batched_tokens,
-            model_config.get_vocab_size(),
-            lora_config,
+            vllm_config,
             device,
             model.embedding_modules,
             model.embedding_padding_modules,
-            max_position_embeddings=text_config.max_position_embeddings,
         )
         return self.lora_manager.create_lora_manager(model)
```
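
With this signature, every runner simply forwards its whole config; the call at each site (CPU and GPU runners above, TPU runner below) collapses to one line:

```python
# New call shape shared by all runner call sites in this commit.
self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
```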

vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -1178,9 +1178,7 @@ def load_model(self) -> None:
                 "or sharding the weights on more chips. "
                 f"See the detailed error: {e}") from e
         if self.lora_config is not None:
-            model = self.load_lora_model(model, self.model_config,
-                                         self.scheduler_config,
-                                         self.lora_config, self.device)
+            model = self.load_lora_model(model, self.vllm_config, self.device)
             replace_set_lora(model)
 
     # Sync all pending XLA execution during model initialization and weight
```

vllm/worker/model_runner.py

Lines changed: 2 additions & 9 deletions
```diff
@@ -1078,20 +1078,13 @@ def load_model(self) -> None:
                     "Regarding multimodal models, vLLM currently "
                     "only supports adding LoRA to language model.")
 
-            # Use get_text_config() in case of multimodal models
-            text_config = self.model_config.hf_config.get_text_config()
-
             self.lora_manager = LRUCacheWorkerLoRAManager(
-                self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size,
-                self.lora_config,
+                self.vllm_config,
                 self.device,
                 self.model.embedding_modules,
                 self.model.embedding_padding_modules,
-                max_position_embeddings=text_config.
-                max_position_embeddings,
             )
+
             self.model = self.lora_manager.create_lora_manager(self.model)
             time_after_load = time.perf_counter()
```
