Commit 0ea80c8

[Model] Define merge_by_field_config MM interface (vllm-project#25676)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent: b8d9e4a

5 files changed: +44 −12 lines


tests/models/multimodal/processing/test_tensor_schema.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -19,6 +19,8 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                   supports_multimodal)
 from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         InputProcessingContext)
@@ -88,6 +90,7 @@ def resize_mm_data(
 
 
 def create_batched_mm_kwargs(
+    model_cls: type[SupportsMultiModal],
    model_config: ModelConfig,
     processor: BaseMultiModalProcessor,
     size_factors: tuple[float, ...] = (1.0, 0.5, 0.25),
@@ -127,16 +130,22 @@ def create_batched_mm_kwargs(
         mm_data=resized_mm_data,
         hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         tokenization_kwargs=processor_inputs.tokenization_kwargs,
-    )["mm_kwargs"]
+    )["mm_kwargs"].require_data()
     items = [
         item for modality in supported_mm_limits
         for item in mm_kwargs[modality]
     ]
-    return group_mm_kwargs_by_modality(items)
+    return group_mm_kwargs_by_modality(
+        items,
+        merge_by_field_config=model_cls.merge_by_field_config,
+    )
 
 
 @contextmanager
-def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
+def initialize_dummy_model(
+    model_cls: type[nn.Module],
+    model_config: ModelConfig,
+):
     temp_file = tempfile.mkstemp()[1]
     init_distributed_environment(
         world_size=1,
@@ -198,8 +207,12 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         hf_overrides=hf_overrides_fn,
         skip_tokenizer_init=model_info.skip_tokenizer_init,
         enforce_eager=model_info.enforce_eager,
-        dtype=model_info.dtype)
+        dtype=model_info.dtype,
+    )
+
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+    assert supports_multimodal(model_cls)
+
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
 
     inputs_parse_methods = []
@@ -228,7 +241,7 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
 
     with initialize_dummy_model(model_cls, model_config) as model:
         for modality, _, mm_kwargs in create_batched_mm_kwargs(
-                model_config, processor):
+                model_cls, model_config, processor):
             for method_name in inputs_parse_methods:
                 print(f"Testing `{method_name}` with modality={modality} "
                       f"and mm_kwargs{list(mm_kwargs.keys())}")
```

vllm/config/model.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -63,13 +63,12 @@
 ConvertOption = Literal["auto", ConvertType]
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription", "draft"]
-_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
-                        "classify", "reward", "draft"]
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
 LogprobsMode = Literal["raw_logits", "raw_logprobs", "processed_logits",
                        "processed_logprobs"]
-HfOverrides = Union[dict[str, Any], Callable[[type], type]]
+HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
+                                             PretrainedConfig]]
 ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
 
 _RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
```
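As context for this type fix: the callable form of `HfOverrides` is now annotated as taking and returning a HuggingFace `PretrainedConfig` rather than a bare `type`, matching how the test above uses it. A minimal sketch (the `rope_scaling` tweak is purely illustrative):

```python
from transformers import PretrainedConfig


def hf_overrides_fn(config: PretrainedConfig) -> PretrainedConfig:
    # Adjust the HF config in place (or return a new one) before vLLM uses it.
    config.rope_scaling = {"rope_type": "dynamic", "factor": 2.0}
    return config
```

This is the shape of the `hf_overrides=hf_overrides_fn` argument that appears in the test diff above.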

vllm/model_executor/models/interfaces.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -64,6 +64,12 @@ class SupportsMultiModal(Protocol):
     `multimodal_config.mm_encoder_tp_mode="data"`.
     """
 
+    merge_by_field_config: ClassVar[bool] = False
+    """
+    A flag that indicates which implementation of
+    `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
+    """
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         """
```

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -40,7 +40,8 @@
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                   is_mixture_of_experts,
                                                    supports_eagle3,
                                                    supports_mrope,
                                                    supports_transcription)
@@ -777,11 +778,13 @@ def _extract_mm_kwargs(
             mm_kwargs.append(feature.data)
 
         # Input all modalities at once
+        model = cast(SupportsMultiModal, self.model)
         mm_kwargs_combined: BatchedTensorInputs = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)
 
@@ -1525,11 +1528,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
+        model = cast(SupportsMultiModal, self.model)
         encoder_outputs = []
         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1538,7 +1543,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             #    (feature_size, hidden_size) in case the feature size is dynamic
             #    depending on the input multimodal items.
-            curr_group_outputs = self.model.get_multimodal_embeddings(
+            curr_group_outputs = model.get_multimodal_embeddings(
                 **mm_kwargs_group)
 
             sanity_check_mm_encoder_outputs(
@@ -1623,11 +1628,13 @@ def _extract_encoder_inputs(
             return {}
 
         # Group MM kwargs by modality and extract features
+        model = cast(SupportsMultiModal, self.model)
         encoder_features = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Add the grouped features to encoder_features dict
             # This allows the model to receive them as kwargs (e.g.,
@@ -2839,11 +2846,13 @@ def _get_mm_dummy_batch(
         dummy_mm_item = dummy_mm_data[modality][0]
         dummy_mm_items = [dummy_mm_item] * max_items_per_batch
 
+        model = cast(SupportsMultiModal, self.model)
         return next(mm_kwargs_group
                     for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                         dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
+                        merge_by_field_config=model.merge_by_field_config,
                     ))
 
     @torch.inference_mode()
```
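All four GPU call sites (and the TPU ones below) follow the same recipe: cast `self.model` to the `SupportsMultiModal` protocol so the new class attribute is visible to the type checker, then forward it to the grouping utility. A condensed sketch of that pattern, with the hypothetical `run_mm_encoder` and `runner` standing in for the model-runner method and instance:

```python
from typing import cast

from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal.utils import group_mm_kwargs_by_modality


def run_mm_encoder(runner, mm_kwargs):
    # The cast is purely for typing: the runner holds an `nn.Module`,
    # while `merge_by_field_config` is declared on the protocol.
    model = cast(SupportsMultiModal, runner.model)

    encoder_outputs = []
    for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=runner.device,
            pin_memory=runner.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
    ):
        # Per the comments in the diff, the encoder returns either a 3D
        # tensor or a list/tuple of 2D tensors; both iterate per item.
        encoder_outputs.extend(
            model.get_multimodal_embeddings(**mm_kwargs_group))
    return encoder_outputs
```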

vllm/v1/worker/tpu_model_runner.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -30,7 +30,8 @@
 from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
-from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                   supports_transcription)
 from vllm.model_executor.models.interfaces_base import (
     is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -834,11 +835,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
+        model = cast(SupportsMultiModal, self.model)
         encoder_outputs = []
         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -848,7 +851,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             #    (feature_size, hidden_size) in case the feature size is dynamic
             #    depending on the input multimodal items.
             torch_xla.sync(wait=False)
-            curr_group_outputs = self.model.get_multimodal_embeddings(
+            curr_group_outputs = model.get_multimodal_embeddings(
                 **mm_kwargs_group)
             torch_xla.sync(wait=False)
 
@@ -1805,11 +1808,13 @@ def _get_mm_dummy_batch(
         dummy_mm_item = dummy_mm_data[modality][0]
         dummy_mm_items = [dummy_mm_item] * max_items_per_batch
 
+        model = cast(SupportsMultiModal, self.model)
         return next(grouped_mm_kwargs
                     for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
                         dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
+                        merge_by_field_config=model.merge_by_field_config,
                     ))
 
```