@@ -1,8 +1,11 @@
+import asyncio
 import base64
+from typing import Union
 
 import numpy as np
 import tiktoken
 from fastapi import APIRouter, Depends
+from openai import AsyncOpenAI
 from openai.types.create_embedding_response import Usage
 from sentence_transformers import SentenceTransformer
 
@@ -23,7 +26,7 @@ def get_embedding_engine():
 async def create_embeddings(
     request: EmbeddingCreateParams,
     model_name: str = None,
-    engine: SentenceTransformer = Depends(get_embedding_engine),
+    client: Union[SentenceTransformer, AsyncOpenAI] = Depends(get_embedding_engine),
 ):
     """Creates embeddings for the text"""
     if request.model is None:
@@ -41,7 +44,7 @@ async def create_embeddings(
             request.input = [decoding.decode(text) for text in request.input]
 
     # https://huggingface.co/BAAI/bge-large-zh
-    if engine is not None and "bge" in SETTINGS.embedding_name.lower():
+    if client is not None and "bge" in SETTINGS.embedding_name.lower():
         instruction = ""
         if "zh" in SETTINGS.embedding_name.lower():
             instruction = "为这个句子生成表示以用于检索相关文章："
@@ -50,30 +53,65 @@ async def create_embeddings(
         request.input = [instruction + q for q in request.input]
 
     data, total_tokens = [], 0
-    batches = [
-        request.input[i : i + 1024] for i in range(0, len(request.input), 1024)
-    ]
-    for num_batch, batch in enumerate(batches):
-        token_num = sum(len(i) for i in batch)
-        vecs = engine.encode(batch, normalize_embeddings=True)
-
-        bs, dim = vecs.shape
-        if SETTINGS.embedding_size > dim:
-            zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
-            vecs = np.c_[vecs, zeros]
-
-        if request.encoding_format == "base64":
-            vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
-        else:
-            vecs = vecs.tolist()
-
-        data.extend(
-            Embedding(
-                index=num_batch * 1024 + i, object="embedding", embedding=embed
+
+    # support for tei: https://github.com/huggingface/text-embeddings-inference
+    if isinstance(client, AsyncOpenAI):
+        global_batch_size = SETTINGS.max_concurrent_requests * SETTINGS.max_client_batch_size
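+        # fan requests out in waves: each wave issues at most
+        # SETTINGS.max_concurrent_requests concurrent calls of
+        # SETTINGS.max_client_batch_size texts each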
+        for i in range(0, len(request.input), global_batch_size):
+            tasks = []
+            texts = request.input[i : i + global_batch_size]
+            for j in range(0, len(texts), SETTINGS.max_client_batch_size):
+                tasks.append(
+                    client.embeddings.create(
+                        input=texts[j : j + SETTINGS.max_client_batch_size],
+                        model=request.model,
+                    )
+                )
+            res = await asyncio.gather(*tasks)
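+            # gather returns results in submission order, so flattening r.data
+            # below keeps the embeddings aligned with the input texts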
+
+            vecs = np.asarray([e.embedding for r in res for e in r.data])
+            bs, dim = vecs.shape
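+            # zero-pad when the model dimension is smaller than the configured
+            # embedding size, so every response has a fixed width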
+            if SETTINGS.embedding_size > dim:
+                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
+                vecs = np.c_[vecs, zeros]
+
+            if request.encoding_format == "base64":
+                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
+            else:
+                vecs = vecs.tolist()
+
+            data.extend(
+                Embedding(
+                    # i is already an offset into request.input, so i + j
+                    # is the absolute index of this embedding
+                    index=i + j,
+                    object="embedding",
+                    embedding=embed
+                )
+                for j, embed in enumerate(vecs)
+            )
+            total_tokens += sum(r.usage.total_tokens for r in res)
+    else:
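+        # local SentenceTransformer path: encode in-process in fixed batches of 1024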
+        batches = [request.input[i : i + 1024] for i in range(0, len(request.input), 1024)]
+        for num_batch, batch in enumerate(batches):
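+            # rough usage accounting: counts characters rather than tokenizer tokens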
+            token_num = sum(len(i) for i in batch)
+            vecs = client.encode(batch, normalize_embeddings=True)
+
+            bs, dim = vecs.shape
+            if SETTINGS.embedding_size > dim:
+                zeros = np.zeros((bs, SETTINGS.embedding_size - dim))
+                vecs = np.c_[vecs, zeros]
+
+            if request.encoding_format == "base64":
+                vecs = [base64.b64encode(v.tobytes()).decode("utf-8") for v in vecs]
+            else:
+                vecs = vecs.tolist()
+
+            data.extend(
+                Embedding(
+                    index=num_batch * 1024 + i, object="embedding", embedding=embed
+                )
+                for i, embed in enumerate(vecs)
             )
-            for i, embed in enumerate(vecs)
-        )
-        total_tokens += token_num
+            total_tokens += token_num
 
     return CreateEmbeddingResponse(
         data=data,
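
Note: the changed route assumes `get_embedding_engine()` can now return either a local `SentenceTransformer` or an `AsyncOpenAI` client pointed at a text-embeddings-inference (TEI) server; that dependency is not part of this hunk. Below is a minimal sketch of what such a factory could look like. `SETTINGS.tei_endpoint` is a hypothetical setting name used only for illustration, and `api_key="none"` is just a placeholder since the OpenAI client requires a non-empty key.

```python
# Sketch only: SETTINGS is the project's config object (import path omitted),
# and `tei_endpoint` is a hypothetical field standing in for however the
# project actually stores the TEI server URL.
from openai import AsyncOpenAI
from sentence_transformers import SentenceTransformer


def get_embedding_engine():
    tei_endpoint = getattr(SETTINGS, "tei_endpoint", None)
    if tei_endpoint:
        # TEI serves an OpenAI-compatible /v1/embeddings route, so the stock
        # AsyncOpenAI client works once base_url points at the TEI server.
        return AsyncOpenAI(base_url=tei_endpoint, api_key="none")
    # Fall back to loading the embedding model in-process.
    return SentenceTransformer(SETTINGS.embedding_name)
```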