Commit c8acd80

[2/N] handling placeholders in merged multi-modal processor (#10485)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 4634a89 commit c8acd80

File tree

5 files changed: +975 −147 lines

5 files changed

+975
-147
lines changed
Lines changed: 370 additions & 0 deletions

@@ -0,0 +1,370 @@

```python
from typing import cast

import pytest
from transformers import BatchFeature

from vllm.multimodal.processing import (PromptReplacement, find_text_matches,
                                        find_token_matches, iter_token_matches,
                                        iter_token_runs, replace_text_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "expected"),
    [
        ([], []),
        (
            [32000, 32000, 32000],
            [{ "token_id": 32000, "start_idx": 0, "length": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [
                { "token_id": 9833, "start_idx": 0, "length": 1 },
                { "token_id": 28747, "start_idx": 1, "length": 1 },
                { "token_id": 32000, "start_idx": 2, "length": 3 },
                { "token_id": 9833, "start_idx": 5, "length": 1 },
                { "token_id": 28747, "start_idx": 6, "length": 1 },
                { "token_id": 32000, "start_idx": 7, "length": 2 },
                { "token_id": 918, "start_idx": 9, "length": 1 },
            ],
        ),
    ],
)
# yapf: enable
def test_iter_token_runs(token_ids, expected):
    result = list(iter_token_runs(token_ids))

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    assert sum(run_info.length for run_info in result) == len(token_ids)
```
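The expected values above imply that `iter_token_runs` performs a run-length grouping of consecutive identical token IDs, yielding named tuples with `token_id`, `start_idx`, and `length` fields (hence the `_asdict()` calls). The real helper lives in `vllm.multimodal.processing`; a minimal sketch that is consistent with these test cases, using the illustrative names `_TokenRun` and `iter_token_runs_sketch` rather than vLLM's own, could look like this:

```python
from itertools import groupby
from typing import Iterable, List, NamedTuple


class _TokenRun(NamedTuple):
    token_id: int
    start_idx: int
    length: int


def iter_token_runs_sketch(token_ids: List[int]) -> Iterable[_TokenRun]:
    """Group consecutive identical token IDs into runs."""
    start_idx = 0
    for token_id, group in groupby(token_ids):
        length = sum(1 for _ in group)
        yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length)
        start_idx += length
```

Any such decomposition satisfies the invariant checked at the end of the test: the run lengths sum to `len(token_ids)`.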
```python
# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], [{ "start_idx": 0, "end_idx": 0 }]),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)
```
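Note the `[32000, 32000, 32000]` vs. `[32000, 32000]` case: a single match at `(0, 2)` is expected, so the search is non-overlapping and resumes after the end of each match, while an empty pattern matches the empty prompt exactly once. A sketch consistent with these expectations (hypothetical names, not the actual implementation) could be:

```python
from typing import Iterable, List, NamedTuple


class _TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int


def iter_token_matches_sketch(
    token_ids: List[int],
    match_ids: List[int],
) -> Iterable[_TokenMatch]:
    """Yield non-overlapping occurrences of ``match_ids`` within ``token_ids``."""
    match_len = len(match_ids)
    start_idx = 0
    while start_idx <= len(token_ids) - match_len:
        end_idx = start_idx + match_len
        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
            # Resume after the match so results never overlap; the max() guard
            # also avoids an infinite loop when match_ids is empty.
            start_idx = max(end_idx, start_idx + 1)
        else:
            start_idx += 1
```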
```python
# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    result = find_token_matches(
        prompt,
        [
            PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key
```
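`find_token_matches` returns a flat list of matches, each bound to the modality key of its `PromptReplacement`, and the assertion regroups them with `full_groupby` from `vllm.utils`. Judging from this usage, `full_groupby` collects all items with equal keys into one group without requiring the input to be sorted; a stand-in with that behaviour (illustrative name, not the vLLM implementation) might be:

```python
from collections import defaultdict
from typing import Callable, Iterable, List, Tuple, TypeVar

_K = TypeVar("_K")
_T = TypeVar("_T")


def full_groupby_sketch(
    values: Iterable[_T],
    *,
    key: Callable[[_T], _K],
) -> Iterable[Tuple[_K, List[_T]]]:
    """Group values by key without requiring the input to be pre-sorted."""
    groups = defaultdict(list)
    for value in values:
        groups[key(value)].append(value)
    return groups.items()
```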
```python
# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    result = find_text_matches(
        prompt,
        [
            PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key
```
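Unlike the token-level test, the indices here are character offsets (`"<image>"` is 7 characters long, hence the 0/7/14/21/28 boundaries), and the `"<|image|>"` cases confirm that targets are matched as literal text rather than regular expressions. A per-target offset computation consistent with these expectations could be built on `re.finditer` with `re.escape`; the function below is only an illustration (the real `find_text_matches` also binds each match to its modality key):

```python
import re
from typing import Dict, List


def find_literal_text_matches_sketch(
    prompt: str,
    target_text: str,
) -> List[Dict[str, int]]:
    """Non-overlapping character-offset matches of a literal target string."""
    # re.escape makes characters such as "|" match literally.
    return [
        dict(start_idx=match.start(), end_idx=match.end())
        for match in re.finditer(re.escape(target_text), prompt)
    ]
```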
```python
# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with repl_unit
                "pattern_1": ("<image><image>", 1),
                # Test empty repl_unit
                "pattern_2": ("", 1),
                # Test multiple repl_count
                "pattern_3": ("?", 2),
            },
            {
                # Test no replacement
                0: "Image:<image>Image:<image><image>!",
                # Test single replacement
                1: "<image><image>Image:<image><image>??",
                # Test repeated replacement
                2: "<image><image><image><image><image>??",
            },
        ),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_mm_count,
):
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    matches = find_text_matches(
        prompt,
        [
            PromptReplacement(target, *repl_by_key[key]) \
                .bind(key, mock_tokenizer)
            for key, target in target_by_key.items()
        ],
    )
    result_by_mm_count = {
        mm_count: replace_text_matches(
            prompt,
            matches,
            {key: list(range(mm_count))
             for key in repl_by_key},
            BatchFeature(),
        )
        for mm_count in expected_by_mm_count
    }

    # Only displayed on error
    print("matches:", matches)
    print("result_by_mm_count:", result_by_mm_count)

    # Manually constructed results
    assert result_by_mm_count == expected_by_mm_count
```
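The expected strings spell out the replacement semantics: for each modality key, only the first `mm_count` matches are replaced, each with `repl_unit` repeated `repl_count` times, and the offsets come from the original prompt, so replacement text (such as the `<image><image>` unit of `pattern_1`) is never re-matched. The actual `replace_text_matches` operates on bound matches, multi-modal items, and a `BatchFeature`; the simplified splice below, over hypothetical `(key, start_idx, end_idx)` triples, reproduces the three expected outputs:

```python
from typing import Dict, Iterable, List, Tuple


def apply_text_matches_sketch(
    prompt: str,
    matches: Iterable[Tuple[str, int, int]],
    repl_by_key: Dict[str, Tuple[str, int]],
    mm_count: int,
) -> str:
    """Splice replacements into ``prompt`` using offsets from the original text."""
    out: List[str] = []
    prev_end = 0
    replaced = {key: 0 for key in repl_by_key}

    for key, start_idx, end_idx in sorted(matches, key=lambda m: m[1]):
        if replaced[key] >= mm_count:
            continue  # leave any remaining matches for this key untouched
        repl_unit, repl_count = repl_by_key[key]
        out.append(prompt[prev_end:start_idx])
        out.append(repl_unit * repl_count)
        prev_end = end_idx
        replaced[key] += 1

    out.append(prompt[prev_end:])
    return "".join(out)
```

For example, with `mm_count=1` this turns `"Image:<image>Image:<image><image>!"` into `"<image><image>Image:<image><image>??"`, matching `expected_by_mm_count[1]` above.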

tests/multimodal/test_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model):
             2,
             "<image><image><image>",
             [32000, 32000, 32000],
-            [{ "offset": 0, "length": 2 }]),
+            [{ "offset": 0, "length": 2 }],
+        ),
         (
             "<image><image>",
             [3, 2],
```

vllm/multimodal/inputs.py

Lines changed: 1 addition & 8 deletions

```diff
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
     """The type of inputs."""

     prompt: str
-    """
-    The original, unprocessed prompt text.
-
-    Note:
-        Since prompt text is not required by vLLM internals, we leave this
-        unprocessed to save CPU computation. You can still call
-        :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
-    """
+    """The processed prompt text."""

     prompt_token_ids: List[int]
     """The processed token IDs which includes placeholder tokens."""
```
