Commit 4a8af77

Author: xusenlin
Support for Text Generation Inference (TGI)
1 parent 69c55fd commit 4a8af77

File tree

20 files changed: +672 -40 lines changed

api/config.py

Lines changed: 6 additions & 0 deletions
@@ -239,6 +239,12 @@ class Settings(BaseModel):
         description="RoPE frequency scaling factor",
     )
 
+    # support for tgi
+    tgi_endpoint: Optional[str] = Field(
+        default=get_env("TGI_ENDPOINT", None),
+        description="Text Generation Inference endpoint.",
+    )
+
 
 SETTINGS = Settings()
 logger.debug(f"SETTINGS: {model_json(SETTINGS, indent=4)}")

api/core/default.py

Lines changed: 4 additions & 4 deletions
@@ -397,7 +397,7 @@ def _create_chat_completion_stream(self, params: Dict[str, Any]) -> Iterator:
         finish_reason=None,
     )
     yield ChatCompletionChunk(
-        id=_id,
+        id=f"chat{_id}",
         choices=[choice],
         created=_created,
         model=_model,
@@ -443,7 +443,7 @@ def _create_chat_completion_stream(self, params: Dict[str, Any]) -> Iterator:
         finish_reason=finish_reason
     )
     yield ChatCompletionChunk(
-        id=_id,
+        id=f"chat{_id}",
         choices=[choice],
         created=_created,
         model=_model,
@@ -457,7 +457,7 @@ def _create_chat_completion_stream(self, params: Dict[str, Any]) -> Iterator:
         finish_reason="stop"
     )
     yield ChatCompletionChunk(
-        id=_id,
+        id=f"chat{_id}",
         choices=[choice],
         created=_created,
         model=_model,
@@ -521,7 +521,7 @@ def _create_chat_completion(self, params: Dict[str, Any]) -> Union[ChatCompletio
     )
     usage = model_parse(CompletionUsage, last_output["usage"])
     return ChatCompletion(
-        id=last_output["id"],
+        id=f"chat{last_output['id']}",
         choices=[choice],
         created=last_output["created"],
         model=last_output["model"],
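
For illustration only (not from the repository): the change namespaces chat-completion ids with a `chat` prefix so they can be told apart from plain completion ids; the raw id below is made up.

```python
_id = "cmpl-9f2c1e"        # hypothetical id produced by the engine
chunk_id = f"chat{_id}"    # id now emitted in ChatCompletionChunk / ChatCompletion
print(chunk_id)            # -> chatcmpl-9f2c1e
```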

api/core/llama_cpp_engine.py

Lines changed: 2 additions & 2 deletions
@@ -129,7 +129,7 @@ def _create_chat_completion_stream(self, prompt, **kwargs) -> Iterator:
         finish_reason=None,
     )
     yield ChatCompletionChunk(
-        id=_id,
+        id=f"chat{_id}",
         choices=[choice],
         created=_created,
         model=_model,
@@ -147,7 +147,7 @@ def _create_chat_completion_stream(self, prompt, **kwargs) -> Iterator:
         finish_reason=output["choices"][0]["finish_reason"],
     )
     yield ChatCompletionChunk(
-        id=_id,
+        id=f"chat{_id}",
         choices=[choice],
         created=_created,
         model=_model,

api/core/tgi.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
import json
from typing import Optional, List, AsyncIterator

from aiohttp import ClientSession
from openai.types.chat import ChatCompletionMessageParam
from pydantic import ValidationError
from text_generation import AsyncClient
from text_generation.errors import parse_error
from text_generation.types import Request, Parameters
from text_generation.types import Response, StreamResponse

from api.adapter import get_prompt_adapter
from api.utils.compat import model_dump


class TGIEngine:
    def __init__(
        self,
        model: AsyncClient,
        model_name: str,
        prompt_name: Optional[str] = None,
    ):
        """
        Initializes the TGIEngine object.

        Args:
            model: The text-generation-inference AsyncClient object.
            model_name: The name of the model.
            prompt_name: The name of the prompt (optional).
        """
        self.model = model
        self.model_name = model_name.lower()
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

    def apply_chat_template(
        self, messages: List[ChatCompletionMessageParam],
    ) -> str:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.

        Returns:
            str: The processed output as a string.
        """
        return self.prompt_adapter.apply_chat_template(messages)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = True,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = True,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously.

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of the highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            decoder_input_details=decoder_input_details,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(
            headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
        ) as session:
            async with session.post(f"{self.model.base_url}/generate", json=model_dump(request)) as resp:
                payload = await resp.json()

                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload)

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = 1,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously.

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of the highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            AsyncIterator: stream of generated tokens
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, parameters=parameters)

        async with ClientSession(
            headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
        ) as session:
            async with session.post(f"{self.model.base_url}/generate_stream", json=model_dump(request)) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                # Parse ServerSentEvents
                async for byte_payload in resp.content:
                    # Skip line
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    # Event data
                    if payload.startswith("data:"):
                        # Decode payload
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        # Parse payload
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            # If we failed to parse the payload, then it is an error payload
                            raise parse_error(resp.status, json_payload)
                        yield response

    @property
    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None
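
Below is a short usage sketch that is not part of the commit: it shows how the new `TGIEngine` could be driven end to end, assuming a TGI server is reachable at the placeholder URL and that the placeholder model name resolves to one of the project's prompt adapters.

```python
# Hypothetical driver for api/core/tgi.py; URL, model name, and prompt are placeholders.
import asyncio

from text_generation import AsyncClient

from api.core.tgi import TGIEngine


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")              # address of a running TGI server
    engine = TGIEngine(client, model_name="llama-2-7b-chat")   # name must match a prompt adapter

    # Build the raw prompt from OpenAI-style chat messages.
    prompt = engine.apply_chat_template([{"role": "user", "content": "Hello!"}])

    # One-shot generation: returns a text_generation.types.Response.
    response = await engine.generate(prompt, max_new_tokens=64, temperature=0.7)
    print(response.generated_text)

    # Streaming generation: yields text_generation.types.StreamResponse objects.
    async for chunk in engine.generate_stream(prompt, max_new_tokens=64, do_sample=True):
        if not chunk.token.special:
            print(chunk.token.text, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())
```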

api/core/vllm_engine.py

Lines changed: 7 additions & 0 deletions
@@ -140,6 +140,13 @@ def generate(self, params: Dict[str, Any], request_id: str) -> AsyncIterator:
         stop=params.get("stop", []),
         stop_token_ids=params.get("stop_token_ids", []),
         max_tokens=params.get("max_tokens", 256),
+        repetition_penalty=params.get("repetition_penalty", 1.03),
+        min_p=params.get("min_p", 0.0),
+        best_of=params.get("best_of", 1),
+        ignore_eos=params.get("ignore_eos", False),
+        use_beam_search=params.get("use_beam_search", False),
+        skip_special_tokens=params.get("skip_special_tokens", True),
+        spaces_between_special_tokens=params.get("spaces_between_special_tokens", True),
     )
     result_generator = self.model.generate(
         prompt_or_messages if isinstance(prompt_or_messages, str) else None,
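
The hunk above forwards more request fields into vLLM's SamplingParams. As a standalone illustration (not from the repository), assuming a vLLM release contemporary with this commit in which `best_of` and `use_beam_search` are still accepted arguments, the equivalent construction looks like this; the values are arbitrary examples except where they mirror the fallbacks above:

```python
# Sketch of the SamplingParams call that the updated generate() effectively makes.
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=0.7,                     # example request value
    top_p=0.9,                           # example request value
    stop=[],                             # stop strings from the request
    max_tokens=256,                      # fallback used above
    repetition_penalty=1.03,             # new: fallback used above
    min_p=0.0,                           # new
    best_of=1,                           # new
    ignore_eos=False,                    # new
    use_beam_search=False,               # new
    skip_special_tokens=True,            # new
    spaces_between_special_tokens=True,  # new
)
```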

api/llama_cpp_routes/chat.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ async def create_chat_completion(
 
     prompt = engine.apply_chat_template(request.messages, request.functions, request.tools)
     include = {
-        "temperature", "temperature", "top_p", "stream", "stop",
+        "temperature", "top_p", "stream", "stop",
        "max_tokens", "presence_penalty", "frequency_penalty", "model"
     }
     kwargs = model_dump(request, include=include)
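
Worth noting (illustration, not part of the diff): because Python set literals deduplicate, the repeated "temperature" key never changed behaviour, so this fix is purely cosmetic.

```python
# Duplicate elements in a set literal collapse silently.
print({"temperature", "temperature", "top_p"} == {"temperature", "top_p"})  # True
```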

api/llama_cpp_routes/completion.py

Lines changed: 2 additions & 2 deletions
@@ -30,10 +30,10 @@ async def create_completion(
     request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
 
     request.max_tokens = request.max_tokens or 256
-    request, stop_token_ids = await handle_request(request, engine.stop)
+    request = await handle_request(request, engine.stop)
 
     include = {
-        "temperature", "temperature", "top_p", "stream", "stop",
+        "temperature", "top_p", "stream", "stop",
        "max_tokens", "presence_penalty", "frequency_penalty", "model"
     }
     kwargs = model_dump(request, include=include)

api/models.py

Lines changed: 16 additions & 0 deletions
@@ -131,6 +131,20 @@ def create_llama_cpp_engine():
     return LlamaCppEngine(engine, SETTINGS.model_name, SETTINGS.chat_template)
 
 
+def create_tgi_engine():
+    """ get TGI generate engine for chat or completion. """
+    try:
+        from text_generation import AsyncClient
+        from api.core.tgi import TGIEngine
+    except ImportError:
+        return None
+
+    client = AsyncClient(SETTINGS.tgi_endpoint)
+    logger.info("Using TGI engine")
+
+    return TGIEngine(client, SETTINGS.model_name, SETTINGS.chat_template)
+
+
 # fastapi app
 app = create_app()
 
@@ -145,6 +159,8 @@ def create_llama_cpp_engine():
     GENERATE_ENGINE = create_vllm_engine()
 elif SETTINGS.engine == "llama.cpp":
     GENERATE_ENGINE = create_llama_cpp_engine()
+elif SETTINGS.engine == "tgi":
+    GENERATE_ENGINE = create_tgi_engine()
 else:
     GENERATE_ENGINE = None
 
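
A hedged end-to-end sketch of selecting the new backend (not part of the commit). The `ENGINE` and `MODEL_NAME` environment-variable names are assumptions inferred from the `get_env("TGI_ENDPOINT", ...)` pattern in api/config.py, and all values are placeholders:

```python
import os

# Assumed variable names; set them before importing api.models so SETTINGS picks them up.
os.environ["ENGINE"] = "tgi"
os.environ["MODEL_NAME"] = "llama-2-7b-chat"
os.environ["TGI_ENDPOINT"] = "http://127.0.0.1:8080"

from api.models import GENERATE_ENGINE  # noqa: E402

# With the text_generation client installed, the dispatch above yields a TGIEngine.
print(type(GENERATE_ENGINE).__name__)
```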
