Skip to content

Commit c0d8a9f

Browse files
author
xusenlin
committed
Update build_qwen_chat_input
1 parent bab2a6c commit c0d8a9f

File tree

3 files changed

+77
-119
lines changed

3 files changed

+77
-119
lines changed

api/core/default.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def build_chat_inputs(
253253
)
254254
elif check_is_qwen(self.model):
255255
inputs = build_qwen_chat_input(
256-
self.tokenizer, messages, self.context_len, max_new_tokens, functions, tools,
256+
self.tokenizer, messages, functions=functions, tools=tools,
257257
)
258258
elif check_is_xverse(self.model):
259259
inputs = build_xverse_chat_input(

api/core/vllm_engine.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,8 @@ def apply_chat_template(
8585
return build_qwen_chat_input(
8686
self.tokenizer,
8787
messages,
88-
self.max_model_len,
89-
max_tokens,
90-
functions,
91-
tools,
88+
functions=functions,
89+
tools=tools,
9290
)
9391
else:
9492
return self.prompt_adapter.apply_chat_template(messages)

api/generation/qwen.py

Lines changed: 74 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import json
2-
import re
32
from copy import deepcopy
43
from typing import List, Union, Optional, Dict, Any, Tuple
54

6-
from fastapi import HTTPException
75
from loguru import logger
86
from openai.types.chat import (
97
ChatCompletionMessageParam,
@@ -12,7 +10,6 @@
1210
)
1311
from transformers import PreTrainedTokenizer
1412

15-
from api.generation.utils import parse_messages
1613
from api.utils.protocol import Role
1714

1815
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
@@ -40,8 +37,7 @@
4037
def build_qwen_chat_input(
4138
tokenizer: PreTrainedTokenizer,
4239
messages: List[ChatCompletionMessageParam],
43-
context_len: int = 8192,
44-
max_new_tokens: int = 256,
40+
max_window_size: int = 6144,
4541
functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
4642
tools: Optional[List[Dict[str, Any]]] = None,
4743
) -> List[int]:
@@ -54,71 +50,65 @@ def build_qwen_chat_input(
5450
Args:
5551
tokenizer: The tokenizer used to encode the input tokens.
5652
messages: The list of chat messages.
57-
context_len: The maximum length of the context.
58-
max_new_tokens: The maximum number of new tokens to add.
53+
max_window_size: The maximum length of the context.
5954
functions: Optional dictionary or list of dictionaries representing the functions.
6055
tools: Optional list of dictionaries representing the tools.
6156
6257
Returns:
6358
The list of input tokens.
6459
"""
65-
query, history = process_qwen_messages(messages, functions, tools)
60+
query, history, system = process_qwen_messages(messages, functions, tools)
6661
if query is _TEXT_COMPLETION_CMD:
67-
return build_last_message_input(tokenizer, history)
68-
69-
messages = []
70-
for q, r in history:
71-
messages.extend(
72-
[
73-
ChatCompletionUserMessageParam(role="user", content=q),
74-
ChatCompletionAssistantMessageParam(role="assistant", content=r)
75-
]
76-
)
77-
messages.append(ChatCompletionUserMessageParam(role="user", content=query))
78-
79-
max_input_tokens = context_len - max_new_tokens
80-
system, rounds = parse_messages(messages)
81-
system = f"You are a helpful assistant.{system}"
62+
return build_last_message_input(tokenizer, history, system)
8263

8364
im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id]
8465
nl_tokens = tokenizer.encode("\n")
8566

86-
def _tokenize_str(role, content):
87-
return tokenizer.encode(
88-
role, allowed_special=set()
89-
) + nl_tokens + tokenizer.encode(content, allowed_special=set())
67+
if hasattr(tokenizer, "IMAGE_ST"):
68+
def _tokenize_str(role, content):
69+
return tokenizer.encode(
70+
role, allowed_special=set(tokenizer.IMAGE_ST)
71+
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
72+
else:
73+
def _tokenize_str(role, content):
74+
return tokenizer.encode(
75+
role, allowed_special=set()
76+
) + nl_tokens + tokenizer.encode(content, allowed_special=set())
9077

9178
system_tokens_part = _tokenize_str("system", system)
9279
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
93-
max_history_tokens = max_input_tokens - len(system_tokens)
9480

95-
history_tokens = []
96-
for r in rounds[::-1]:
97-
round_tokens = []
98-
for message in r:
99-
if round_tokens:
100-
round_tokens += nl_tokens
81+
context_tokens = []
82+
for turn_query, turn_response in reversed(history):
83+
query_tokens_part = _tokenize_str("user", turn_query)
84+
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
10185

102-
if message["role"] == Role.USER:
103-
content_tokens = im_start_tokens + _tokenize_str("user", message["content"]) + im_end_tokens
104-
else:
105-
content_tokens = im_start_tokens + _tokenize_str("assistant", message["content"]) + im_end_tokens
86+
response_tokens_part = _tokenize_str("assistant", turn_response)
87+
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
10688

107-
round_tokens.extend(content_tokens)
89+
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
10890

109-
if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
110-
if history_tokens:
111-
history_tokens = nl_tokens + history_tokens
91+
current_context_size = (
92+
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
93+
)
94+
if current_context_size < max_window_size:
95+
context_tokens = next_context_tokens + context_tokens
96+
else:
97+
break
11298

113-
history_tokens = round_tokens + history_tokens # concat left
114-
if len(history_tokens) < max_history_tokens:
115-
continue
116-
break
99+
context_tokens = system_tokens + context_tokens
100+
context_tokens += (
101+
nl_tokens
102+
+ im_start_tokens
103+
+ _tokenize_str("user", query)
104+
+ im_end_tokens
105+
+ nl_tokens
106+
+ im_start_tokens
107+
+ tokenizer.encode("assistant")
108+
+ nl_tokens
109+
)
117110

118-
input_tokens = system_tokens + nl_tokens + history_tokens
119-
if messages[-1]["role"] != Role.ASSISTANT:
120-
input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
121-
return input_tokens[-max_input_tokens:] # truncate left
111+
return context_tokens
122112

123113

124114
def check_is_qwen(model) -> bool:
@@ -138,7 +128,7 @@ def process_qwen_messages(
138128
messages: List[ChatCompletionMessageParam],
139129
functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
140130
tools: Optional[List[Dict[str, Any]]] = None,
141-
) -> Tuple[str, List[List[str]]]:
131+
) -> Tuple[str, List[List[str]], str]:
142132
"""
143133
Process the Qwen messages and generate a query and history.
144134
@@ -148,21 +138,16 @@ def process_qwen_messages(
148138
tools (Optional[List[Dict[str, Any]]]): The tools to be used.
149139
150140
Returns:
151-
Tuple[str, List[List[str]]]: The generated query and history.
141+
Tuple[str, List[List[str]], str]: The generated query, history, and system prompt.
152142
"""
153143
if all(m["role"] != Role.USER for m in messages):
154-
raise HTTPException(
155-
status_code=400,
156-
detail=f"Invalid request: Expecting at least one user message.",
157-
)
144+
raise ValueError(f"Invalid messages: Expecting at least one user message.")
158145

159146
messages = deepcopy(messages)
160-
default_system = "You are a helpful assistant."
161-
system = ""
162147
if messages[0]["role"] == Role.SYSTEM:
163148
system = messages.pop(0)["content"].lstrip("\n").rstrip()
164-
if system == default_system:
165-
system = ""
149+
else:
150+
system = "You are a helpful assistant."
166151

167152
if tools:
168153
functions = [t["function"] for t in tools]
@@ -191,55 +176,37 @@ def process_qwen_messages(
191176

192177
tools_text = "\n\n".join(tools_text)
193178
tools_name_text = ", ".join(tools_name_text)
194-
system += "\n\n" + REACT_INSTRUCTION.format(
179+
instruction = REACT_INSTRUCTION.format(
195180
tools_text=tools_text,
196181
tools_name_text=tools_name_text,
197-
)
198-
system = system.lstrip("\n").rstrip()
182+
).lstrip('\n').rstrip()
183+
else:
184+
instruction = ""
199185

200-
dummy_thought = {
201-
"en": "\nThought: I now know the final answer.\nFinal answer: ",
202-
"zh": "\nThought: 我会作答了。\nFinal answer: ",
203-
}
204-
205-
_messages = messages
186+
messages_with_fncall = messages
206187
messages = []
207-
for m_idx, m in enumerate(_messages):
188+
for m_idx, m in enumerate(messages_with_fncall):
208189
role, content = m["role"], m["content"]
209190
func_call, tool_calls = m.get("function_call", None), m.get("tool_calls", None)
210-
if content:
211-
content = content.lstrip("\n").rstrip()
191+
192+
content = content or ''
193+
content = content.lstrip('\n').rstrip()
194+
212195
if role in [Role.FUNCTION, Role.TOOL]:
213196
if (len(messages) == 0) or (messages[-1]["role"] != Role.ASSISTANT):
214-
raise HTTPException(
215-
status_code=400,
216-
detail=f"Invalid request: Expecting role assistant before role function.",
217-
)
197+
raise ValueError(f"Invalid messages: Expecting role assistant before role function.")
198+
218199
messages[-1]["content"] += f"\nObservation: {content}"
219-
if m_idx == len(_messages) - 1:
200+
if m_idx == len(messages_with_fncall) - 1:
220201
messages[-1]["content"] += "\nThought:"
202+
221203
elif role == Role.ASSISTANT:
222204
if len(messages) == 0:
223-
raise HTTPException(
224-
status_code=400,
225-
detail=f"Invalid request: Expecting role user before role assistant.",
226-
)
227-
last_msg = messages[-1]["content"]
228-
last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
205+
raise ValueError(f"Invalid messages: Expecting role user before role assistant.")
229206

230207
if func_call is None and tool_calls is None:
231208
if functions or tool_calls:
232-
content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
233-
else:
234-
if func_call:
235-
f_name, f_args = func_call.get("name"), func_call.get("arguments")
236-
else:
237-
f_name, f_args = tool_calls[0]["function"]["name"], tool_calls[0]["function"]["arguments"]
238-
if not content:
239-
if last_msg_has_zh:
240-
content = f"Thought: 我可以使用 {f_name} API。"
241-
else:
242-
content = f"Thought: I can use {f_name}."
209+
content = f"Thought: I now know the final answer.\nFinal Answer: {content}"
243210

244211
if messages[-1]["role"] == Role.USER:
245212
messages.append(
@@ -252,46 +219,39 @@ def process_qwen_messages(
252219
ChatCompletionUserMessageParam(role="user", content=content.lstrip("\n").rstrip())
253220
)
254221
else:
255-
raise HTTPException(
256-
status_code=400, detail=f"Invalid request: Incorrect role {role}."
257-
)
222+
raise ValueError(f"Invalid messages: Incorrect role {role}.")
258223

259224
query = _TEXT_COMPLETION_CMD
260225
if messages[-1]["role"] == Role.USER:
261226
query = messages[-1]["content"]
262227
messages = messages[:-1]
263228

264229
if len(messages) % 2 != 0:
265-
raise HTTPException(status_code=400, detail="Invalid request")
230+
raise ValueError("Invalid messages")
266231

267232
history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
268233
for i in range(0, len(messages), 2):
269234
if messages[i]["role"] == Role.USER and messages[i + 1]["role"] == Role.ASSISTANT:
270235
usr_msg = messages[i]["content"].lstrip("\n").rstrip()
271236
bot_msg = messages[i + 1]["content"].lstrip("\n").rstrip()
272-
if system and (i == len(messages) - 2):
273-
usr_msg = f"{system}\n\nQuestion: {usr_msg}"
274-
system = ""
275-
for t in dummy_thought.values():
276-
t = t.lstrip("\n")
277-
if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
278-
bot_msg = bot_msg[len(t):]
237+
if instruction and (i == len(messages) - 2):
238+
usr_msg = f"{instruction}\n\nQuestion: {usr_msg}"
239+
instruction = ''
279240
history.append([usr_msg, bot_msg])
280241
else:
281-
raise HTTPException(
282-
status_code=400,
283-
detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
284-
)
285-
if system:
242+
raise ValueError("Invalid messages: Expecting exactly one user (or function) role before every assistant role.")
243+
244+
if instruction:
286245
assert query is not _TEXT_COMPLETION_CMD
287-
query = f"{system}\n\nQuestion: {query}"
288-
return query, history
246+
query = f"{instruction}\n\nQuestion: {query}"
247+
248+
return query, history, system
289249

290250

291-
def build_last_message_input(tokenizer: PreTrainedTokenizer, history: list):
251+
def build_last_message_input(tokenizer: PreTrainedTokenizer, history: List[List[str]], system: str):
292252
im_start = "<|im_start|>"
293253
im_end = "<|im_end|>"
294-
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
254+
prompt = f"{im_start}system\n{system}{im_end}"
295255
for i, (query, response) in enumerate(history):
296256
query = query.lstrip("\n").rstrip()
297257
response = response.lstrip("\n").rstrip()

0 commit comments

Comments
 (0)