16 changes: 3 additions & 13 deletions tpu_commons/models/vllm/vllm_model_wrapper.py
@@ -121,9 +121,7 @@ def load_weights(self):
         # the device in load_lora_model is used to set the device used in punica wrapper.
         lora_manager, vllm_model = load_lora_model(
             vllm_model,
-            vllm_config_for_load.model_config,
-            vllm_config_for_load.scheduler_config,
-            vllm_config_for_load.lora_config,
+            vllm_config_for_load,
             device="jax")
         self._register_lora_weights_as_param(
             vllm_model, vllm_config_for_load.lora_config)
Expand Down Expand Up @@ -243,8 +241,7 @@ def compute_logits_func(
return compute_logits_func


def load_lora_model(model: torch.nn.Module, model_config: ModelConfig,
scheduler_config: SchedulerConfig, lora_config: LoRAConfig,
def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
device: str) -> torch.nn.Module:
if not supports_lora(model):
raise ValueError(
@@ -254,19 +251,12 @@ def load_lora_model(model: torch.nn.Module, model_config: ModelConfig,
         logger.warning("Regarding multimodal models, vLLM currently "
                        "only supports adding LoRA to language model.")

-    # Use get_text_config() in case of multimodal models
-    text_config = model_config.hf_config.get_text_config()
-
     # Add LoRA Manager to the Model Runner
     lora_manager = LRUCacheWorkerLoRAManager(
-        scheduler_config.max_num_seqs,
-        scheduler_config.max_num_batched_tokens,
-        model_config.get_vocab_size(),
-        lora_config,
+        vllm_config,
         device,
         model.embedding_modules,
         model.embedding_padding_modules,
-        max_position_embeddings=text_config.max_position_embeddings,
     )
     return lora_manager, lora_manager.create_lora_manager(model)

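Note for context: with this refactor, `load_lora_model` takes the whole `VllmConfig`, and `LRUCacheWorkerLoRAManager` is expected to pull the scheduler, model, and LoRA sub-configs (including the text config's `max_position_embeddings`) out of it internally, matching upstream vLLM's refactor of the worker LoRA managers to a single-config constructor. A minimal sketch of the new call site, assuming a populated `VllmConfig`; the surrounding setup is not shown in this PR, and the names mirror the hunks above:

    # Hypothetical usage sketch, not part of the diff.
    lora_manager, vllm_model = load_lora_model(
        vllm_model,            # a torch.nn.Module that supports LoRA
        vllm_config_for_load,  # one VllmConfig in place of the three
                               # separate model/scheduler/lora configs
        device="jax")

Passing the single config object keeps this wrapper's signature in step with vLLM, where `VllmConfig` already aggregates `model_config`, `scheduler_config`, and `lora_config` as fields.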