Commit 829bbd7

[New Model] mBART model (vllm-project#22883)
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
1 parent 4dff91c · commit 829bbd7

File tree

6 files changed: +716 −91 lines changed

docs/models/supported_models.md

Lines changed: 4 additions & 0 deletions
@@ -330,6 +330,7 @@ th {
 | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
 | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
 | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
+| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
 | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
@@ -418,6 +419,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th
 !!! note
     Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
+!!! note
+    Some mBART models' config files do not define `architectures`. You therefore need to pass `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly select the `MBartForConditionalGeneration` architecture.
+
 ### Pooling Models
 
 See [this page](./pooling_models.md) for more information on how to use pooling models.
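
For the offline API, the same override can be passed programmatically rather than as a CLI flag; the example script in this commit does exactly that via `LLM(hf_overrides=...)`. A minimal sketch (model ID taken from the table above; assumes a working vLLM install):

```python
from vllm import LLM

# Pin the architecture for mBART checkpoints whose config.json
# lacks an "architectures" entry (equivalent to the --hf-overrides flag).
llm = LLM(
    model="facebook/mbart-large-en-ro",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)
```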

examples/offline_inference/encoder_decoder.py

Lines changed: 147 additions & 86 deletions
@@ -2,9 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 Demonstrate prompting of text-to-text
-encoder/decoder models, specifically BART
+encoder/decoder models, specifically BART and mBART.
+
+This script is refactored to allow model selection via command-line arguments.
 """
 
+import argparse
+from typing import NamedTuple, Optional
+
 from vllm import LLM, SamplingParams
 from vllm.inputs import (
     ExplicitEncoderDecoderPrompt,
@@ -14,119 +19,175 @@
 )
 
 
-def create_prompts(tokenizer):
-    # Test prompts
-    #
-    # This section shows all of the valid ways to prompt an
-    # encoder/decoder model.
-    #
-    # - Helpers for building prompts
-    text_prompt_raw = "Hello, my name is"
-    text_prompt = TextPrompt(prompt="The president of the United States is")
-    tokens_prompt = TokensPrompt(
-        prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
+class ModelRequestData(NamedTuple):
+    """
+    Holds the configuration for a specific model, including its
+    HuggingFace ID and the prompts to use for the demo.
+    """
+
+    model_id: str
+    encoder_prompts: list
+    decoder_prompts: list
+    hf_overrides: Optional[dict] = None
+
+
+def get_bart_config() -> ModelRequestData:
+    """
+    Returns the configuration for facebook/bart-large-cnn.
+    This uses the exact test cases from the original script.
+    """
+    encoder_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "An encoder prompt",
+    ]
+    decoder_prompts = [
+        "A decoder prompt",
+        "Another decoder prompt",
+    ]
+    return ModelRequestData(
+        model_id="facebook/bart-large-cnn",
+        encoder_prompts=encoder_prompts,
+        decoder_prompts=decoder_prompts,
     )
-    # - Pass a single prompt to encoder/decoder model
-    # (implicitly encoder input prompt);
-    # decoder input prompt is assumed to be None
-
-    single_text_prompt_raw = text_prompt_raw  # Pass a string directly
-    single_text_prompt = text_prompt  # Pass a TextPrompt
-    single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt
-
-    # ruff: noqa: E501
-    # - Pass explicit encoder and decoder input prompts within one data structure.
-    # Encoder and decoder prompts can both independently be text or tokens, with
-    # no requirement that they be the same prompt type. Some example prompt-type
-    # combinations are shown below, note that these are not exhaustive.
-
-    enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
-        # Pass encoder prompt string directly, &
-        # pass decoder prompt tokens
-        encoder_prompt=single_text_prompt_raw,
-        decoder_prompt=single_tokens_prompt,
-    )
-    enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
-        # Pass TextPrompt to encoder, and
-        # pass decoder prompt string directly
-        encoder_prompt=single_text_prompt,
-        decoder_prompt=single_text_prompt_raw,
+
+
+def get_mbart_config() -> ModelRequestData:
+    """
+    Returns the configuration for facebook/mbart-large-en-ro.
+    This uses prompts suitable for an English-to-Romanian translation task.
+    """
+    encoder_prompts = [
+        "The quick brown fox jumps over the lazy dog.",
+        "How are you today?",
+    ]
+    decoder_prompts = ["", ""]
+    hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
+    return ModelRequestData(
+        model_id="facebook/mbart-large-en-ro",
+        encoder_prompts=encoder_prompts,
+        decoder_prompts=decoder_prompts,
+        hf_overrides=hf_overrides,
    )
-    enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
-        # Pass encoder prompt tokens directly, and
-        # pass TextPrompt to decoder
-        encoder_prompt=single_tokens_prompt,
-        decoder_prompt=single_text_prompt,
+
+
+MODEL_GETTERS = {
+    "bart": get_bart_config,
+    "mbart": get_mbart_config,
+}
+
+
+def create_all_prompt_types(
+    encoder_prompts_raw: list,
+    decoder_prompts_raw: list,
+    tokenizer,
+) -> list:
+    """
+    Generates a list of diverse prompt types for demonstration.
+    This function is generic and uses the provided raw prompts
+    to create various vLLM input objects.
+    """
+    text_prompt_raw = encoder_prompts_raw[0]
+    text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
+    tokens_prompt = TokensPrompt(
+        prompt_token_ids=tokenizer.encode(
+            encoder_prompts_raw[2 % len(encoder_prompts_raw)]
+        )
     )
 
-    # - Finally, here's a useful helper function for zipping encoder and
-    # decoder prompts together into a list of ExplicitEncoderDecoderPrompt
-    # instances
+    decoder_tokens_prompt = TokensPrompt(
+        prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
+    )
+    single_prompt_examples = [
+        text_prompt_raw,
+        text_prompt,
+        tokens_prompt,
+    ]
+    explicit_pair_examples = [
+        ExplicitEncoderDecoderPrompt(
+            encoder_prompt=text_prompt_raw,
+            decoder_prompt=decoder_tokens_prompt,
+        ),
+        ExplicitEncoderDecoderPrompt(
+            encoder_prompt=text_prompt,
+            decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
+        ),
+        ExplicitEncoderDecoderPrompt(
+            encoder_prompt=tokens_prompt,
+            decoder_prompt=text_prompt,
+        ),
+    ]
     zipped_prompt_list = zip_enc_dec_prompts(
-        ["An encoder prompt", "Another encoder prompt"],
-        ["A decoder prompt", "Another decoder prompt"],
+        encoder_prompts_raw,
+        decoder_prompts_raw,
     )
+    return single_prompt_examples + explicit_pair_examples + zipped_prompt_list
 
-    # - Let's put all of the above example prompts together into one list
-    # which we will pass to the encoder/decoder LLM.
-    return [
-        single_text_prompt_raw,
-        single_text_prompt,
-        single_tokens_prompt,
-        enc_dec_prompt1,
-        enc_dec_prompt2,
-        enc_dec_prompt3,
-    ] + zipped_prompt_list
 
-
-# Create a sampling params object.
-def create_sampling_params():
+def create_sampling_params() -> SamplingParams:
+    """Create a sampling params object."""
     return SamplingParams(
         temperature=0,
         top_p=1.0,
         min_tokens=0,
-        max_tokens=20,
+        max_tokens=30,
     )
 
 
-# Print the outputs.
-def print_outputs(outputs):
-    print("-" * 50)
+def print_outputs(outputs: list):
+    """Formats and prints the generation outputs."""
+    print("-" * 80)
     for i, output in enumerate(outputs):
         prompt = output.prompt
         encoder_prompt = output.encoder_prompt
         generated_text = output.outputs[0].text
         print(f"Output {i + 1}:")
-        print(
-            f"Encoder prompt: {encoder_prompt!r}\n"
-            f"Decoder prompt: {prompt!r}\n"
-            f"Generated text: {generated_text!r}"
+        print(f"Encoder Prompt: {encoder_prompt!r}")
+        print(f"Decoder Prompt: {prompt!r}")
+        print(f"Generated Text: {generated_text!r}")
+        print("-" * 80)
+
+
+def main(args):
+    """Main execution function."""
+    model_key = args.model
+    if model_key not in MODEL_GETTERS:
+        raise ValueError(
+            f"Unknown model: {model_key}. "
+            f"Available models: {list(MODEL_GETTERS.keys())}"
         )
-    print("-" * 50)
-
-
-def main():
-    dtype = "float"
+    config_getter = MODEL_GETTERS[model_key]
+    model_config = config_getter()
 
-    # Create a BART encoder/decoder model instance
+    print(f"🚀 Running demo for model: {model_config.model_id}")
     llm = LLM(
-        model="facebook/bart-large-cnn",
-        dtype=dtype,
+        model=model_config.model_id,
+        dtype="float",
+        hf_overrides=model_config.hf_overrides,
     )
-
-    # Get BART tokenizer
     tokenizer = llm.llm_engine.get_tokenizer_group()
-
-    prompts = create_prompts(tokenizer)
+    prompts = create_all_prompt_types(
+        encoder_prompts_raw=model_config.encoder_prompts,
+        decoder_prompts_raw=model_config.decoder_prompts,
+        tokenizer=tokenizer,
+    )
     sampling_params = create_sampling_params()
-
-    # Generate output tokens from the prompts. The output is a list of
-    # RequestOutput objects that contain the prompt, generated
-    # text, and other information.
     outputs = llm.generate(prompts, sampling_params)
-
     print_outputs(outputs)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser(
+        description="A flexible demo for vLLM encoder-decoder models."
    )
+    parser.add_argument(
+        "--model",
+        "-m",
+        type=str,
+        default="bart",
+        choices=MODEL_GETTERS.keys(),
+        help="The short name of the model to run.",
+    )
+    args = parser.parse_args()
+    main(args)
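
Given the `argparse` setup above, the refactored demo is selected per model from the shell; for instance, the mBART configuration (including its `hf_overrides`) would run as:

```
python examples/offline_inference/encoder_decoder.py --model mbart
```

Omitting `--model` keeps the `bart` default, which reproduces the original script's behavior.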
