
Commit f2f2825

Author: xusenlin (committed)
Support for text-embeddings-inference
1 parent: 8f5e292 · commit: f2f2825

File tree

    api/config.py
    api/models.py
    api/routes/embedding.py
    api/routes/model.py
    api/server.py

5 files changed: 93 additions, 35 deletions

api/config.py

Lines changed: 17 additions & 3 deletions
@@ -39,7 +39,7 @@ class Settings(BaseModel):
     )
     engine: Optional[str] = Field(
         default=get_env("ENGINE", "default"),
-        description="Choices are ['default', 'vllm', 'llama.cpp'].",
+        description="Choices are ['default', 'vllm', 'llama.cpp', 'tgi'].",
     )
 
     # model related
@@ -239,10 +239,24 @@ class Settings(BaseModel):
         description="RoPE frequency scaling factor",
     )
 
-    # support for tgi
+    # support for tgi: https://github.com/huggingface/text-generation-inference
     tgi_endpoint: Optional[str] = Field(
         default=get_env("TGI_ENDPOINT", None),
-        description="Text Generate Inference Endpoint.",
+        description="Text Generation Inference Endpoint.",
+    )
+
+    # support for tei: https://github.com/huggingface/text-embeddings-inference
+    tei_endpoint: Optional[str] = Field(
+        default=get_env("TEI_ENDPOINT", None),
+        description="Text Embeddings Inference Endpoint.",
+    )
+    max_concurrent_requests: Optional[int] = Field(
+        default=int(get_env("MAX_CONCURRENT_REQUESTS", 256)),
+        description="The maximum amount of concurrent requests for this particular deployment."
+    )
+    max_client_batch_size: Optional[int] = Field(
+        default=int(get_env("MAX_CLIENT_BATCH_SIZE", 32)),
+        description="Control the maximum number of inputs that a client can send in a single request."
     )
 
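The three new settings are environment-driven. As a minimal sketch of how a deployment might wire them up (the endpoint URL below is an assumption, not part of the commit; TEI_ENDPOINT should point at a running text-embeddings-inference server with no route suffix, since the OpenAI client appends /embeddings itself):

# Hypothetical wiring; the URL and values are illustrative only.
import os

os.environ["TEI_ENDPOINT"] = "http://localhost:8080"   # assumed TEI server address
os.environ["MAX_CONCURRENT_REQUESTS"] = "256"          # same as the Field default
os.environ["MAX_CLIENT_BATCH_SIZE"] = "32"             # same as the Field default

from api.config import SETTINGS  # Settings reads these via get_env when the module loads

assert SETTINGS.tei_endpoint == "http://localhost:8080"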

api/models.py

Lines changed: 7 additions & 3 deletions
@@ -21,9 +21,13 @@ def create_app() -> FastAPI:
 
 def create_embedding_model():
     """ get embedding model from sentence-transformers. """
-    from sentence_transformers import SentenceTransformer
-
-    return SentenceTransformer(SETTINGS.embedding_name, device=SETTINGS.embedding_device)
+    if SETTINGS.tei_endpoint is not None:
+        from openai import AsyncOpenAI
+        client = AsyncOpenAI(base_url=SETTINGS.tei_endpoint, api_key="none")
+    else:
+        from sentence_transformers import SentenceTransformer
+        client = SentenceTransformer(SETTINGS.embedding_name, device=SETTINGS.embedding_device)
+    return client
 
 
 def create_generate_model():
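This dispatch works because text-embeddings-inference exposes an OpenAI-compatible embeddings route, so a stock AsyncOpenAI client can stand in for a local SentenceTransformer. A self-contained sketch of the TEI branch (the server address and model name are assumptions for illustration):

# Minimal sketch, not part of the commit: call a TEI server the way the
# TEI branch of create_embedding_model() would. The OpenAI SDK posts to
# {base_url}/embeddings, which TEI serves as its OpenAI-compatible route.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8080", api_key="none")  # assumed address
    resp = await client.embeddings.create(
        input=["hello world"],
        model="BAAI/bge-large-zh",  # illustrative; TEI serves whatever model it was launched with
    )
    print(len(resp.data[0].embedding))  # embedding dimension


asyncio.run(main())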

api/routes/embedding.py

Lines changed: 63 additions & 25 deletions
@@ -1,8 +1,11 @@
+import asyncio
 import base64
+from typing import Union
 
 import numpy as np
 import tiktoken
 from fastapi import APIRouter, Depends
+from openai import AsyncOpenAI
 from openai.types.create_embedding_response import Usage
 from sentence_transformers import SentenceTransformer
 
@@ -23,7 +26,7 @@ def get_embedding_engine():
 async def create_embeddings(
     request: EmbeddingCreateParams,
     model_name: str = None,
-    engine: SentenceTransformer = Depends(get_embedding_engine),
+    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
 ):
     """Creates embeddings for the text"""
     if request.model is None:
@@ -41,7 +44,7 @@ async def create_embeddings(
     request.input = [decoding.decode(text) for text in request.input]
 
     # https://huggingface.co/BAAI/bge-large-zh
-    if engine is not None and "bge" in SETTINGS.embedding_name.lower():
+    if client is not None and "bge" in SETTINGS.embedding_name.lower():
         instruction = ""
         if "zh" in SETTINGS.embedding_name.lower():
             instruction = "为这个句子生成表示以用于检索相关文章:"
@@ -50,30 +53,65 @@ async def create_embeddings(
         request.input = [instruction + q for q in request.input]
 
     data, total_tokens = [], 0
-    batches = [
-        request.input[i: i + 1024] for i in range(0, len(request.input), 1024)
-    ]
-    for num_batch, batch in enumerate(batches):
-        token_num = sum(len(i) for i in batch)
-        vecs = engine.encode(batch, normalize_embeddings=True)
-
-        bs, dim = vecs.shape
-        if SETTINGS.embedding_size > dim:
-            zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
-            vecs = np.c_[vecs, zeros]
-
-        if request.encoding_format == "base64":
-            vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
-        else:
-            vecs = vecs.tolist()
-
-        data.extend(
-            Embedding(
-                index=num_batch * 1024 + i, object="embedding", embedding=embed
-            )
-            for i, embed in enumerate(vecs)
-        )
-        total_tokens += token_num
+
+    # support for tei: https://github.com/huggingface/text-embeddings-inference
+    if isinstance(client, AsyncOpenAI):
+        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
+        for i in range(0, len(request.input), global_batch_size):
+            tasks = []
+            texts = request.input[i: i + global_batch_size]
+            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
+                tasks.append(
+                    client.embeddings.create(
+                        input=texts[j: j + SETTINGS.max_client_batch_size],
+                        model=request.model,
+                    )
+                )
+            res = await asyncio.gather(*tasks)
+
+            vecs = np.asarray([e.embedding for r in res for e in r.data])
+            bs, dim = vecs.shape
+            if SETTINGS.embedding_size > dim:
+                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
+                vecs = np.c_[vecs, zeros]
+
+            if request.encoding_format == "base64":
+                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
+            else:
+                vecs = vecs.tolist()
+
+            data.extend(
+                Embedding(
+                    index=i + j,  # i is the global-batch offset, j the position within it
+                    object="embedding",
+                    embedding=embed
+                )
+                for j, embed in enumerate(vecs)
+            )
+            total_tokens += sum(r.usage.total_tokens for r in res)
+    else:
+        batches = [request.input[i: i + 1024] for i in range(0, len(request.input), 1024)]
+        for num_batch, batch in enumerate(batches):
+            token_num = sum(len(i) for i in batch)
+            vecs = client.encode(batch, normalize_embeddings=True)
+
+            bs, dim = vecs.shape
+            if SETTINGS.embedding_size > dim:
+                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
+                vecs = np.c_[vecs, zeros]
+
+            if request.encoding_format == "base64":
+                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
+            else:
+                vecs = vecs.tolist()
+
+            data.extend(
+                Embedding(
+                    index=num_batch * 1024 + i, object="embedding", embedding=embed
+                )
+                for i, embed in enumerate(vecs)
+            )
+            total_tokens += token_num
 
     return CreateEmbeddingResponse(
         data=data,
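The TEI path fans requests out in two levels: the input is cut into global batches of max_concurrent_requests * max_client_batch_size texts, and each global batch is split into client-sized chunks that are awaited concurrently with asyncio.gather. A toy sketch of that scheduling, with a stub standing in for client.embeddings.create (all names and sizes here are illustrative, not repo code):

# Toy illustration of the two-level batching used above.
import asyncio
from typing import List

MAX_CONCURRENT_REQUESTS = 4   # stand-ins for the SETTINGS values
MAX_CLIENT_BATCH_SIZE = 3


async def fake_embed(chunk: List[str]) -> List[List[float]]:
    # stand-in for client.embeddings.create(...)
    await asyncio.sleep(0)
    return [[float(len(s))] for s in chunk]


async def embed_all(texts: List[str]) -> List[List[float]]:
    out: List[List[float]] = []
    global_batch = MAX_CONCURRENT_REQUESTS * MAX_CLIENT_BATCH_SIZE
    for i in range(0, len(texts), global_batch):
        batch = texts[i: i + global_batch]
        # one task per client-sized chunk, all awaited concurrently
        tasks = [
            fake_embed(batch[j: j + MAX_CLIENT_BATCH_SIZE])
            for j in range(0, len(batch), MAX_CLIENT_BATCH_SIZE)
        ]
        for res in await asyncio.gather(*tasks):  # gather preserves task order
            out.extend(res)
    return out


vectors = asyncio.run(embed_all([f"text {n}" for n in range(25)]))
print(len(vectors))  # 25: output count and order match the input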

api/routes/model.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ class ModelList(BaseModel):
 available_models = ModelList(
     data=[
         Model(
-            id=SETTINGS.model_name,
+            id=SETTINGS.model_name or "",
             object="model",
             created=int(time.time()),
             owned_by="open"

api/server.py

Lines changed: 5 additions & 3 deletions
@@ -1,10 +1,8 @@
 from api.config import SETTINGS
 from api.models import app, EMBEDDED_MODEL, GENERATE_ENGINE
-from api.routes import model_router
 
 
 prefix = SETTINGS.api_prefix
-app.include_router(model_router, prefix=prefix, tags=["Model"])
 
 if EMBEDDED_MODEL is not None:
     from api.routes.embedding import embedding_router
@@ -13,6 +11,10 @@
 
 
 if GENERATE_ENGINE is not None:
+    from api.routes import model_router
+
+    app.include_router(model_router, prefix=prefix, tags=["Model"])
+
     if SETTINGS.engine == "vllm":
         from api.vllm_routes import chat_router as chat_router
         from api.vllm_routes import completion_router as completion_router
@@ -29,7 +31,7 @@
         from api.routes.chat import chat_router as chat_router
         from api.routes.completion import completion_router as completion_router
 
-    app.include_router(chat_router, prefix=prefix, tags=["Chat"])
+    app.include_router(chat_router, prefix=prefix, tags=["Chat Completion"])
     app.include_router(completion_router, prefix=prefix, tags=["Completion"])
 
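Moving model_router under the GENERATE_ENGINE check means an embeddings-only deployment (for instance, TEI_ENDPOINT set but no generate model) serves just the embedding routes. A toy illustration of the pattern on a bare FastAPI app (names are stand-ins, not this repo's code):

# Routers are mounted only when their backend exists; without a generate
# engine the model-listing route is simply absent (requests get a 404).
from fastapi import APIRouter, FastAPI

app = FastAPI()
model_router = APIRouter()


@model_router.get("/models")
def list_models():
    return {"data": []}


GENERATE_ENGINE = None  # stand-in: no generate backend configured

if GENERATE_ENGINE is not None:
    app.include_router(model_router, prefix="/v1", tags=["Model"])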
