@@ -457,17 +457,13 @@ def __init__(self,
457457 self .enable_prompt_adapter = (self .runner .prompt_adapter_config
458458 is not None )
459459 self .multi_modal_input_mapper = self .runner .multi_modal_input_mapper
460- self .finished_requests_ids = finished_requests_ids
461460 self .decode_only = True
462461
463- # Intermediate data (data in CPU before going to GPU) for
464- # the current sequence group.
465- self .inter_data_list : List [
466- ModelInputForGPUBuilder .InterDataForSeqGroup ] = []
467-
468462 # Attention metadata inputs.
469- self .attn_metadata_builder = self .attn_backend .make_metadata_builder (
470- weakref .proxy (self ))
463+ if self .attn_backend is not None :
464+ # spec decode (e.g. Medusa) does not have atten backend
465+ self .attn_metadata_builder = self .attn_backend .get_builder_cls ()(
466+ weakref .proxy (self ))
471467
472468 # Engine/Model configurations.
473469 self .chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ def __init__(self,
479475 self .block_aligned_sliding_window = \
480476 self .sliding_window_blocks * self .block_size
481477
def prepare(self,
            finished_requests_ids: Optional[List[str]] = None) -> None:
    """Reset the builder's per-batch state before building a new input.

    The builder is constructed once (in ``__init__``) and reused across
    batches, so all per-batch state must be re-initialised here rather
    than in the constructor.

    Args:
        finished_requests_ids: ids of the requests that finished since
            the previous batch, or ``None`` when there are none.
    """
    self.finished_requests_ids = finished_requests_ids

    # Intermediate data (data in CPU before going to GPU) for
    # the current sequence group.  String annotation: the type is only
    # a hint and must not be evaluated on every call.
    self.inter_data_list: "List[ModelInputForGPUBuilder.InterDataForSeqGroup]" = []

    # Mirror the guard used at construction time: spec decode
    # (e.g. Medusa) has no attention backend, in which case
    # `attn_metadata_builder` was never created and must not be touched.
    if self.attn_backend is not None:
        self.attn_metadata_builder.prepare()
482489 def _compute_lens (self , inter_data : InterDataForSeqGroup , seq_idx : int ,
483490 seq_group_metadata : SequenceGroupMetadata ):
484491 """Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
9931000 """
9941001 _model_input_cls : Type [TModelInputForGPU ]
9951002 _builder_cls : Type [ModelInputForGPUBuilder ]
1003+ builder : ModelInputForGPUBuilder
9961004
9971005 def __init__ (
9981006 self ,
@@ -1093,6 +1101,10 @@ def __init__(
10931101 SamplingMetadataCache () \
10941102 if self .parallel_config .pipeline_parallel_size == 1 else None
10951103
1104+ if hasattr (self , "_builder_cls" ):
1105+ # multi-step model runner does not have `_builder_cls`
1106+ self .builder = self ._builder_cls (weakref .proxy (self ))
1107+
10961108 def load_model (self ) -> None :
10971109 logger .info ("Starting to load model %s..." , self .model_config .model )
10981110 with DeviceMemoryProfiler () as m :
@@ -1226,13 +1238,13 @@ def _prepare_model_input_tensors(
12261238
12271239 If cuda graph is required, this API automatically pads inputs.
12281240 """
1229- builder = self ._builder_cls ( weakref . proxy ( self ), finished_requests_ids )
1241+ self .builder . prepare ( finished_requests_ids )
12301242 for seq_group_metadata in seq_group_metadata_list :
1231- builder .add_seq_group (seq_group_metadata )
1243+ self . builder .add_seq_group (seq_group_metadata )
12321244
1233- builder .reset_cached_inter_data ()
1245+ self . builder .reset_cached_inter_data ()
12341246
1235- return builder .build () # type: ignore
1247+ return self . builder .build () # type: ignore
12361248
12371249 @contextmanager
12381250 def set_in_profile_run (self ):
0 commit comments