44 Protocol , Type , Union , overload , runtime_checkable )
55
66import torch
7+ from torch import Tensor
78from typing_extensions import TypeIs , TypeVar
89
910from vllm .logger import init_logger
1516
1617if TYPE_CHECKING :
1718 from vllm .attention import AttentionMetadata
18- from vllm .multimodal .inputs import NestedTensors # noqa: F401
1919 from vllm .sequence import IntermediateTensors
2020
2121logger = init_logger (__name__ )
2222
23- T = TypeVar ("T" , default = "NestedTensors" )
23+ T = TypeVar ("T" , default = Union [ list [ Tensor ], Tensor , tuple [ Tensor , ...]] )
2424
2525
2626@runtime_checkable
@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
3636 MRO of your model class.
3737 """
3838
39- def get_multimodal_embeddings (self , ** kwargs ) -> Optional [ T ] :
39+ def get_multimodal_embeddings (self , ** kwargs ) -> T :
4040 """
4141 Returns multimodal embeddings generated from multimodal kwargs
4242 to be merged with text embeddings.
@@ -59,18 +59,18 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
5959 @overload
6060 def get_input_embeddings (
6161 self ,
62- input_ids : torch . Tensor ,
62+ input_ids : Tensor ,
6363 multimodal_embeddings : Optional [T ] = None ,
6464 attn_metadata : Optional ["AttentionMetadata" ] = None ,
65- ) -> torch . Tensor :
65+ ) -> Tensor :
6666 ...
6767
6868 @overload
6969 def get_input_embeddings (
7070 self ,
71- input_ids : torch . Tensor ,
71+ input_ids : Tensor ,
7272 multimodal_embeddings : Optional [T ] = None ,
73- ) -> torch . Tensor :
73+ ) -> Tensor :
7474 """
7575 Returns the input embeddings merged from the text embeddings from
7676 input_ids and the multimodal embeddings generated from multimodal
@@ -210,7 +210,7 @@ def forward(
210210 self ,
211211 * ,
212212 intermediate_tensors : Optional ["IntermediateTensors" ],
213- ) -> Union [torch . Tensor , "IntermediateTensors" ]:
213+ ) -> Union [Tensor , "IntermediateTensors" ]:
214214 """
215215 Accept :class:`IntermediateTensors` when PP rank > 0.
216216
@@ -237,7 +237,7 @@ def forward(
237237 self ,
238238 * ,
239239 intermediate_tensors : Optional ["IntermediateTensors" ],
240- ) -> Union [torch . Tensor , "IntermediateTensors" ]:
240+ ) -> Union [Tensor , "IntermediateTensors" ]:
241241 ...
242242
243243
0 commit comments