1 | 1 | # SPDX-License-Identifier: Apache-2.0
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | """ |
4 | 4 | Demonstrate prompting of text-to-text |
5 | | -encoder/decoder models, specifically BART |
| 5 | +encoder/decoder models, specifically BART and mBART. |
| 6 | +
| 7 | +The model to run can be selected via a command-line argument.
6 | 8 | """ |
7 | 9 |
| 10 | +import argparse |
| 11 | +from typing import NamedTuple, Optional |
| 12 | + |
8 | 13 | from vllm import LLM, SamplingParams |
9 | 14 | from vllm.inputs import ( |
10 | 15 | ExplicitEncoderDecoderPrompt, |
11 | 16 | TextPrompt,
12 | 17 | TokensPrompt,
13 | 18 | zip_enc_dec_prompts,
14 | 19 | ) |
15 | 20 |
16 | 21 |
17 | | -def create_prompts(tokenizer): |
18 | | - # Test prompts |
19 | | - # |
20 | | - # This section shows all of the valid ways to prompt an |
21 | | - # encoder/decoder model. |
22 | | - # |
23 | | - # - Helpers for building prompts |
24 | | - text_prompt_raw = "Hello, my name is" |
25 | | - text_prompt = TextPrompt(prompt="The president of the United States is") |
26 | | - tokens_prompt = TokensPrompt( |
27 | | - prompt_token_ids=tokenizer.encode(prompt="The capital of France is") |
| 22 | +class ModelRequestData(NamedTuple): |
| 23 | + """ |
| 24 | + Holds the configuration for a specific model, including its |
| 25 | + HuggingFace ID and the prompts to use for the demo. |
| 26 | + """ |
| 27 | + |
| 28 | + model_id: str |
| 29 | + encoder_prompts: list[str]
| 30 | + decoder_prompts: list[str]
| 31 | + hf_overrides: Optional[dict] = None |
| 32 | + |
| 33 | + |
| 34 | +def get_bart_config() -> ModelRequestData: |
| 35 | + """ |
| 36 | + Returns the configuration for facebook/bart-large-cnn. |
| 37 | + This uses the exact test cases from the original script. |
| 38 | + """ |
| 39 | + encoder_prompts = [ |
| 40 | + "Hello, my name is", |
| 41 | + "The president of the United States is", |
| 42 | + "The capital of France is", |
| 43 | + "An encoder prompt", |
| 44 | + ] |
| 45 | + decoder_prompts = [ |
| 46 | + "A decoder prompt", |
| 47 | + "Another decoder prompt", |
| 48 | + ] |
| 49 | + return ModelRequestData( |
| 50 | + model_id="facebook/bart-large-cnn", |
| 51 | + encoder_prompts=encoder_prompts, |
| 52 | + decoder_prompts=decoder_prompts, |
28 | 53 | ) |
29 | | - # - Pass a single prompt to encoder/decoder model |
30 | | - # (implicitly encoder input prompt); |
31 | | - # decoder input prompt is assumed to be None |
32 | | - |
33 | | - single_text_prompt_raw = text_prompt_raw # Pass a string directly |
34 | | - single_text_prompt = text_prompt # Pass a TextPrompt |
35 | | - single_tokens_prompt = tokens_prompt # Pass a TokensPrompt |
36 | | - |
37 | | - # ruff: noqa: E501 |
38 | | - # - Pass explicit encoder and decoder input prompts within one data structure. |
39 | | - # Encoder and decoder prompts can both independently be text or tokens, with |
40 | | - # no requirement that they be the same prompt type. Some example prompt-type |
41 | | - # combinations are shown below, note that these are not exhaustive. |
42 | | - |
43 | | - enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( |
44 | | - # Pass encoder prompt string directly, & |
45 | | - # pass decoder prompt tokens |
46 | | - encoder_prompt=single_text_prompt_raw, |
47 | | - decoder_prompt=single_tokens_prompt, |
48 | | - ) |
49 | | - enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( |
50 | | - # Pass TextPrompt to encoder, and |
51 | | - # pass decoder prompt string directly |
52 | | - encoder_prompt=single_text_prompt, |
53 | | - decoder_prompt=single_text_prompt_raw, |
| 54 | + |
| 55 | + |
| 56 | +def get_mbart_config() -> ModelRequestData: |
| 57 | + """ |
| 58 | + Returns the configuration for facebook/mbart-large-en-ro. |
| 59 | + This uses prompts suitable for an English-to-Romanian translation task. |
| 60 | + """ |
| 61 | + encoder_prompts = [ |
| 62 | + "The quick brown fox jumps over the lazy dog.", |
| 63 | + "How are you today?", |
| 64 | + ] |
| 65 | + decoder_prompts = ["", ""]  # Empty: the model generates translations from scratch.
| 66 | + hf_overrides = {"architectures": ["MBartForConditionalGeneration"]} |
| 67 | + return ModelRequestData( |
| 68 | + model_id="facebook/mbart-large-en-ro", |
| 69 | + encoder_prompts=encoder_prompts, |
| 70 | + decoder_prompts=decoder_prompts, |
| 71 | + hf_overrides=hf_overrides, |
54 | 72 | ) |
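The `hf_overrides` dict is forwarded into the checkpoint's Hugging Face config before the model is loaded, so the entry above pins which architecture vLLM resolves for this checkpoint. A minimal standalone sketch of the same idea (assuming the override is needed for this checkpoint on your vLLM build):

```python
from vllm import LLM

# Sketch: hf_overrides entries are merged into the checkpoint's HF config;
# here the override forces vLLM to pick the MBart implementation.
llm = LLM(
    model="facebook/mbart-large-en-ro",
    hf_overrides={"architectures": ["MBartForConditionalGeneration"]},
)
```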
55 | | - enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( |
56 | | - # Pass encoder prompt tokens directly, and |
57 | | - # pass TextPrompt to decoder |
58 | | - encoder_prompt=single_tokens_prompt, |
59 | | - decoder_prompt=single_text_prompt, |
| 73 | + |
| 74 | + |
| 75 | +MODEL_GETTERS = { |
| 76 | + "bart": get_bart_config, |
| 77 | + "mbart": get_mbart_config, |
| 78 | +} |
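Because the model configs live in a plain registry dict, extending the demo takes one getter and one entry. A hypothetical example (the model ID below is a placeholder, not a tested checkpoint):

```python
# Hypothetical: register another seq2seq model with the demo's registry.
def get_custom_config() -> ModelRequestData:
    return ModelRequestData(
        model_id="your-org/your-seq2seq-model",  # placeholder, not a real checkpoint
        encoder_prompts=["An example encoder prompt"],
        decoder_prompts=[""],
    )


MODEL_GETTERS["custom"] = get_custom_config
```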
| 79 | + |
| 80 | + |
| 81 | +def create_all_prompt_types( |
| 82 | + encoder_prompts_raw: list[str],
| 83 | + decoder_prompts_raw: list[str],
| 84 | + tokenizer, |
| 85 | +) -> list: |
| 86 | + """ |
| 87 | + Build examples of every supported way to prompt an encoder/decoder
| 88 | + model: raw strings, TextPrompt, TokensPrompt, and explicit pairs.
| 89 | + Indexing uses modulo so shorter prompt lists still work.
| 90 | + """ |
| 91 | + text_prompt_raw = encoder_prompts_raw[0] |
| 92 | + text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)]) |
| 93 | + tokens_prompt = TokensPrompt( |
| 94 | + prompt_token_ids=tokenizer.encode( |
| 95 | + encoder_prompts_raw[2 % len(encoder_prompts_raw)] |
| 96 | + ) |
60 | 97 | ) |
61 | 98 |
62 | | - # - Finally, here's a useful helper function for zipping encoder and |
63 | | - # decoder prompts together into a list of ExplicitEncoderDecoderPrompt |
64 | | - # instances |
| 99 | + decoder_tokens_prompt = TokensPrompt( |
| 100 | + prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0]) |
| 101 | + ) |
| 102 | + single_prompt_examples = [ |
| 103 | + text_prompt_raw, |
| 104 | + text_prompt, |
| 105 | + tokens_prompt, |
| 106 | + ] |
| 107 | + explicit_pair_examples = [ |
| 108 | + ExplicitEncoderDecoderPrompt( |
| 109 | + encoder_prompt=text_prompt_raw, |
| 110 | + decoder_prompt=decoder_tokens_prompt, |
| 111 | + ), |
| 112 | + ExplicitEncoderDecoderPrompt( |
| 113 | + encoder_prompt=text_prompt, |
| 114 | + decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)], |
| 115 | + ), |
| 116 | + ExplicitEncoderDecoderPrompt( |
| 117 | + encoder_prompt=tokens_prompt, |
| 118 | + decoder_prompt=text_prompt, |
| 119 | + ), |
| 120 | + ] |
65 | 121 | zipped_prompt_list = zip_enc_dec_prompts( |
66 | | - ["An encoder prompt", "Another encoder prompt"], |
67 | | - ["A decoder prompt", "Another decoder prompt"], |
| 122 | + encoder_prompts_raw, |
| 123 | + decoder_prompts_raw, |
68 | 124 | ) |
| 125 | + return single_prompt_examples + explicit_pair_examples + zipped_prompt_list |
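For readers new to the helper: `zip_enc_dec_prompts` pairs the two lists element-wise into `ExplicitEncoderDecoderPrompt` instances, so the call above behaves roughly like this simplified sketch (ignoring optional per-request kwargs):

```python
# Rough, simplified equivalent of zip_enc_dec_prompts:
zipped_prompt_list = [
    ExplicitEncoderDecoderPrompt(encoder_prompt=enc, decoder_prompt=dec)
    for enc, dec in zip(encoder_prompts_raw, decoder_prompts_raw)
]
```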
69 | 126 |
70 | | - # - Let's put all of the above example prompts together into one list |
71 | | - # which we will pass to the encoder/decoder LLM. |
72 | | - return [ |
73 | | - single_text_prompt_raw, |
74 | | - single_text_prompt, |
75 | | - single_tokens_prompt, |
76 | | - enc_dec_prompt1, |
77 | | - enc_dec_prompt2, |
78 | | - enc_dec_prompt3, |
79 | | - ] + zipped_prompt_list |
80 | 127 |
81 | | - |
82 | | -# Create a sampling params object. |
83 | | -def create_sampling_params(): |
| 128 | +def create_sampling_params() -> SamplingParams: |
| 129 | + """Create a sampling params object.""" |
84 | 130 | return SamplingParams( |
85 | 131 | temperature=0, |
86 | 132 | top_p=1.0, |
87 | 133 | min_tokens=0, |
88 | | - max_tokens=20, |
| 134 | + max_tokens=30, |
89 | 135 | ) |
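`temperature=0` selects greedy decoding, which keeps the demo deterministic across runs. If sampled output is preferred, a variant could look like this (values are illustrative, not tuned):

```python
# Illustrative alternative: stochastic sampling instead of greedy decoding.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=30,
)
```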
90 | 136 |
91 | 137 |
92 | | -# Print the outputs. |
93 | | -def print_outputs(outputs): |
94 | | - print("-" * 50) |
| 138 | +def print_outputs(outputs: list): |
| 139 | + """Formats and prints the generation outputs.""" |
| 140 | + print("-" * 80) |
95 | 141 | for i, output in enumerate(outputs): |
96 | 142 | prompt = output.prompt |
97 | 143 | encoder_prompt = output.encoder_prompt |
98 | 144 | generated_text = output.outputs[0].text |
99 | 145 | print(f"Output {i + 1}:") |
100 | | - print( |
101 | | - f"Encoder prompt: {encoder_prompt!r}\n" |
102 | | - f"Decoder prompt: {prompt!r}\n" |
103 | | - f"Generated text: {generated_text!r}" |
| 146 | + print(f"Encoder Prompt: {encoder_prompt!r}") |
| 147 | + print(f"Decoder Prompt: {prompt!r}") |
| 148 | + print(f"Generated Text: {generated_text!r}") |
| 149 | + print("-" * 80) |
| 150 | + |
| 151 | + |
| 152 | +def main(args): |
| 153 | + """Main execution function.""" |
| 154 | + model_key = args.model |
| 155 | + if model_key not in MODEL_GETTERS: |
| 156 | + raise ValueError( |
| 157 | + f"Unknown model: {model_key}. " |
| 158 | + f"Available models: {list(MODEL_GETTERS.keys())}" |
104 | 159 | ) |
105 | | - print("-" * 50) |
106 | | - |
107 | | - |
108 | | -def main(): |
109 | | - dtype = "float" |
| 160 | + config_getter = MODEL_GETTERS[model_key] |
| 161 | + model_config = config_getter() |
110 | 162 |
111 | | - # Create a BART encoder/decoder model instance |
| 163 | + print(f"🚀 Running demo for model: {model_config.model_id}") |
112 | 164 | llm = LLM( |
113 | | - model="facebook/bart-large-cnn", |
114 | | - dtype=dtype, |
| 165 | + model=model_config.model_id, |
| 166 | + dtype="float", |
| 167 | + hf_overrides=model_config.hf_overrides, |
115 | 168 | ) |
116 | | - |
117 | | - # Get BART tokenizer |
118 | 169 | tokenizer = llm.llm_engine.get_tokenizer_group() |
119 | | - |
120 | | - prompts = create_prompts(tokenizer) |
| 170 | + prompts = create_all_prompt_types( |
| 171 | + encoder_prompts_raw=model_config.encoder_prompts, |
| 172 | + decoder_prompts_raw=model_config.decoder_prompts, |
| 173 | + tokenizer=tokenizer, |
| 174 | + ) |
121 | 175 | sampling_params = create_sampling_params() |
122 | | - |
123 | | - # Generate output tokens from the prompts. The output is a list of |
124 | | - # RequestOutput objects that contain the prompt, generated |
125 | | - # text, and other information. |
126 | 176 | outputs = llm.generate(prompts, sampling_params) |
127 | | - |
128 | 177 | print_outputs(outputs) |
129 | 178 |
130 | 179 |
131 | 180 | if __name__ == "__main__": |
132 | | - main() |
| 181 | + parser = argparse.ArgumentParser( |
| 182 | + description="A flexible demo for vLLM encoder-decoder models." |
| 183 | + ) |
| 184 | + parser.add_argument( |
| 185 | + "--model", |
| 186 | + "-m", |
| 187 | + type=str, |
| 188 | + default="bart", |
| 189 | + choices=MODEL_GETTERS.keys(), |
| 190 | + help="The short name of the model to run.", |
| 191 | + ) |
| 192 | + args = parser.parse_args() |
| 193 | + main(args) |
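Assuming the script is saved as `encoder_decoder_example.py` (the filename here is a guess; use the file's actual path), the two demos can be run with:

```
python encoder_decoder_example.py                # defaults to --model bart
python encoder_decoder_example.py --model mbart  # English-to-Romanian demo
```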