Skip to content

Commit b3cf368

Browse files
authored
[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
1 parent c8525f0 commit b3cf368

File tree

22 files changed

+249
-150
lines changed

22 files changed

+249
-150
lines changed

examples/offline_inference/vision_language.py

Lines changed: 176 additions & 118 deletions
Large diffs are not rendered by default.

vllm/model_executor/models/aria.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,9 @@ def _process_image_input(
602602

603603
return self.multi_modal_projector(image_outputs, image_attn_mask)
604604

605-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
605+
def get_multimodal_embeddings(
606+
self, **kwargs
607+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
606608
image_input = self._parse_and_validate_image_input(**kwargs)
607609
if image_input is None:
608610
return None

vllm/model_executor/models/blip2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,9 @@ def _process_image_input(self,
628628

629629
return self.language_projection(query_output)
630630

631-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
631+
def get_multimodal_embeddings(
632+
self, **kwargs
633+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
632634
image_input = self._parse_and_validate_image_input(**kwargs)
633635
if image_input is None:
634636
return None

vllm/model_executor/models/chameleon.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,9 @@ def _parse_and_validate_image_input(
986986
data=self._validate_pixel_values(pixel_values),
987987
)
988988

989-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
989+
def get_multimodal_embeddings(
990+
self, **kwargs
991+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
990992
image_input = self._parse_and_validate_image_input(**kwargs)
991993
if image_input is None:
992994
return None

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,9 @@ def _process_image_input(
606606
return self._pixel_values_to_embedding(
607607
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
608608

609-
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
609+
def get_multimodal_embeddings(
610+
self, **kwargs: object
611+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
610612
image_input = self._parse_and_validate_image_input(**kwargs)
611613
if image_input is None:
612614
return None

vllm/model_executor/models/florence2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,9 @@ def _process_image_input(
10371037
pixel_values = image_input["data"]
10381038
return self._encode_image(pixel_values)
10391039

1040-
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
1040+
def get_multimodal_embeddings(
1041+
self, **kwargs: object
1042+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
10411043
image_input = self._parse_and_validate_image_input(**kwargs)
10421044
if image_input is None:
10431045
return None

vllm/model_executor/models/fuyu.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
""" PyTorch Fuyu model."""
1919
import math
2020
from collections.abc import Iterable, Mapping, Sequence
21-
from typing import List, Literal, Optional, Set, Tuple, TypedDict
21+
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
2222

2323
import torch
2424
import torch.nn as nn
@@ -327,7 +327,9 @@ def _process_image_input(
327327
image_patches_flat)
328328
return vision_embeddings_flat.split(patches_per_image, dim=0)
329329

330-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
330+
def get_multimodal_embeddings(
331+
self, **kwargs
332+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
331333
image_input = self._parse_and_validate_image_input(**kwargs)
332334
if image_input is None:
333335
return None

vllm/model_executor/models/glm4v.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,9 @@ def _process_image_input(
595595

596596
return self.transformer.vision(pixel_values)
597597

598-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
598+
def get_multimodal_embeddings(
599+
self, **kwargs
600+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
599601
image_input = self._parse_and_validate_image_input(**kwargs)
600602
if image_input is None:
601603
return None

vllm/model_executor/models/idefics3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
617617
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
618618
self.sampler = get_sampler()
619619

620-
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
620+
def get_multimodal_embeddings(
621+
self, **kwargs
622+
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
621623
image_input = self.model._parse_and_validate_image_input(**kwargs)
622624
if image_input is None:
623625
return None

vllm/model_executor/models/interfaces.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Protocol, Type, Union, overload, runtime_checkable)
55

66
import torch
7+
from torch import Tensor
78
from typing_extensions import TypeIs, TypeVar
89

910
from vllm.logger import init_logger
@@ -15,12 +16,11 @@
1516

1617
if TYPE_CHECKING:
1718
from vllm.attention import AttentionMetadata
18-
from vllm.multimodal.inputs import NestedTensors # noqa: F401
1919
from vllm.sequence import IntermediateTensors
2020

2121
logger = init_logger(__name__)
2222

23-
T = TypeVar("T", default="NestedTensors")
23+
T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])
2424

2525

2626
@runtime_checkable
@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
3636
MRO of your model class.
3737
"""
3838

39-
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
39+
def get_multimodal_embeddings(self, **kwargs) -> T:
4040
"""
4141
Returns multimodal embeddings generated from multimodal kwargs
4242
to be merged with text embeddings.
@@ -59,18 +59,18 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
5959
@overload
6060
def get_input_embeddings(
6161
self,
62-
input_ids: torch.Tensor,
62+
input_ids: Tensor,
6363
multimodal_embeddings: Optional[T] = None,
6464
attn_metadata: Optional["AttentionMetadata"] = None,
65-
) -> torch.Tensor:
65+
) -> Tensor:
6666
...
6767

6868
@overload
6969
def get_input_embeddings(
7070
self,
71-
input_ids: torch.Tensor,
71+
input_ids: Tensor,
7272
multimodal_embeddings: Optional[T] = None,
73-
) -> torch.Tensor:
73+
) -> Tensor:
7474
"""
7575
Returns the input embeddings merged from the text embeddings from
7676
input_ids and the multimodal embeddings generated from multimodal
@@ -210,7 +210,7 @@ def forward(
210210
self,
211211
*,
212212
intermediate_tensors: Optional["IntermediateTensors"],
213-
) -> Union[torch.Tensor, "IntermediateTensors"]:
213+
) -> Union[Tensor, "IntermediateTensors"]:
214214
"""
215215
Accept :class:`IntermediateTensors` when PP rank > 0.
216216
@@ -237,7 +237,7 @@ def forward(
237237
self,
238238
*,
239239
intermediate_tensors: Optional["IntermediateTensors"],
240-
) -> Union[torch.Tensor, "IntermediateTensors"]:
240+
) -> Union[Tensor, "IntermediateTensors"]:
241241
...
242242

243243

0 commit comments

Comments
 (0)