 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                    is_mixture_of_experts,
                                                     supports_eagle3,
                                                     supports_mrope,
                                                     supports_transcription)
@@ -777,11 +778,13 @@ def _extract_mm_kwargs(
             mm_kwargs.append(feature.data)
 
         # Input all modalities at once
+        model = cast(SupportsMultiModal, self.model)
         mm_kwargs_combined: BatchedTensorInputs = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)
 
@@ -1525,11 +1528,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
+        model = cast(SupportsMultiModal, self.model)
         encoder_outputs = []
         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1538,7 +1543,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             #    (feature_size, hidden_size) in case the feature size is dynamic
             #    depending on the input multimodal items.
-            curr_group_outputs = self.model.get_multimodal_embeddings(
+            curr_group_outputs = model.get_multimodal_embeddings(
                 **mm_kwargs_group)
 
             sanity_check_mm_encoder_outputs(
@@ -1623,11 +1628,13 @@ def _extract_encoder_inputs(
             return {}
 
         # Group MM kwargs by modality and extract features
+        model = cast(SupportsMultiModal, self.model)
         encoder_features = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Add the grouped features to encoder_features dict
             # This allows the model to receive them as kwargs (e.g.,
@@ -2839,11 +2846,13 @@ def _get_mm_dummy_batch(
         dummy_mm_item = dummy_mm_data[modality][0]
         dummy_mm_items = [dummy_mm_item] * max_items_per_batch
 
+        model = cast(SupportsMultiModal, self.model)
         return next(mm_kwargs_group
                     for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                         dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
+                        merge_by_field_config=model.merge_by_field_config,
                     ))
 
     @torch.inference_mode()