1515# See the License for the specific language governing permissions and
1616# limitations under the License.
1717
18- from abc import ABC , abstractmethod
18+ from abc import ABC
1919from typing import List
2020import os
2121from fastchat .conversation import get_conv_template , Conversation
22- from neural_chat .pipeline .inference .inference import load_model , predict , predict_stream
22+ from neural_chat .pipeline .inference .inference import load_model , predict , predict_stream , MODELS
2323from neural_chat .config import GenerationConfig
24+ from neural_chat .plugins import is_plugin_enabled , get_plugin_instance , get_registered_plugins , get_plugin_arguments
2425from neural_chat .utils .common import is_audio_file
2526from neural_chat .pipeline .plugins .prompts .prompt import generate_qa_prompt , generate_prompt
2627
@@ -59,7 +60,7 @@ def construct_prompt(query, retriever, retrieval_type):
5960 return generate_qa_prompt (query , context )
6061 else :
6162 return generate_prompt (query )
62-
63+
6364
6465class BaseModel (ABC ):
6566 """
@@ -79,6 +80,7 @@ def __init__(self):
7980 self .retrieval_type = None
8081 self .safety_checker = None
8182 self .intent_detection = False
83+ self .cache = None
8284
8385 def match (self , model_path : str ):
8486 """
@@ -148,34 +150,42 @@ def predict(self, query, config=None):
148150 if is_audio_file (query ):
149151 if not os .path .exists (query ):
150152 raise ValueError (f"The audio file path { query } is invalid." )
151- if self .asr :
152- query = self .asr .audio2text (query )
153- else :
154- raise ValueError (f"The query { query } is audio file but there is no ASR registered." )
153+
154+ # plugin pre actions
155+ for plugin_name in get_registered_plugins ():
156+ if is_plugin_enabled (plugin_name ):
157+ plugin_instance = get_plugin_instance (plugin_name )
158+ if plugin_instance :
159+ if hasattr (plugin_instance , 'pre_llm_inference_actions' ):
160+ if plugin_name == "asr" and not is_audio_file (query ):
161+ continue
162+ if plugin_name == "intent_detection" :
163+ response = plugin_instance .pre_llm_inference_actions (query ,
164+ MODELS [self .model_name ]["model" ], MODELS [self .model_name ]["tokenizer" ])
165+ else :
166+ response = plugin_instance .pre_llm_inference_actions (query )
167+ if plugin_name == "safety_checker" and response :
168+ return "Your query contains sensitive words, please try another query."
169+ elif plugin_name == "intent_detection" :
170+ if 'qa' not in response .lower ():
171+ query = generate_prompt (query )
172+ else :
173+ query = generate_qa_prompt (query )
174+ else :
175+ query = response
155176 assert query is not None , "Query cannot be None."
156177
157- if self .intent_detection :
158- intent = predict (** construct_parameters (query , self .model_name , config .intent_config ))
159- if 'qa' not in intent .lower ():
160- intent = "chitchat"
161- query = generate_prompt (query )
162- elif self .retriever :
163- query = construct_prompt (query , self .retriever , self .retrieval_type )
164- else :
165- query = generate_qa_prompt (query )
166- else :
167- if self .retriever :
168- query = construct_prompt (query , self .retriever , self .retrieval_type )
169-
170- if self .safety_checker :
171- assert self .safety_checker .sensitive_check (query ) is False , "The input query contains sensitive words."
178+ # LLM inference
172179 response = predict (** construct_parameters (query , self .model_name , config ))
173- if self .safety_checker :
174- if self .safety_checker .sensitive_check (response ):
175- response = self .safety_checker .sensitive_filter (response )
176- if self .tts :
177- self .tts .text2speech (response , config .audio_output_path )
178- response = config .audio_output_path
180+
181+ # plugin post actions
182+ for plugin_name in get_registered_plugins ():
183+ if is_plugin_enabled (plugin_name ):
184+ plugin_instance = get_plugin_instance (plugin_name )
185+ if plugin_instance :
186+ if hasattr (plugin_instance , 'post_llm_inference_actions' ):
187+ response = plugin_instance .post_llm_inference_actions (response )
188+
179189 return response
180190
181191 def chat_stream (self , query , config = None ):
@@ -210,43 +220,29 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
210220 """
211221 return get_conv_template ("one_shot" )
212222
213- def register_tts (self , instance ):
214- """
215- Register a text-to-speech (TTS) instance.
216-
217- Args:
218- instance: An instance of a TTS module.
219- """
220- self .tts = instance
221-
222- def register_asr (self , instance ):
223- """
224- Register an automatic speech recognition (ASR) instance.
225-
226- Args:
227- instance: An instance of an ASR module.
228- """
229- self .asr = instance
230-
231- def register_safety_checker (self , instance ):
232- """
233- Register a safety checker instance.
234-
235- Args:
236- instance: An instance of a safety checker module.
237- """
238- self .safety_checker = instance
239-
240- def register_retriever (self , retriever , retrieval_type ):
223+ def register_plugin_instance (self , plugin_name , instance ):
241224 """
242- Register a database retriever .
225+ Register a plugin instance .
243226
244227 Args:
245- instance: An instance of a retriever.
246- retrieval_type: The type of the retrieval method.
228+ instance: An instance of a plugin.
247229 """
248- self .retriever = retriever
249- self .retrieval_type = retrieval_type
230+ if plugin_name == "tts" :
231+ self .tts = instance
232+ if plugin_name == "tts_chinese" :
233+ self .tts_chinese = instance
234+ if plugin_name == "asr" :
235+ self .asr = instance
236+ if plugin_name == "asr_chinese" :
237+ self .asr_chinese = instance
238+ if plugin_name == "retrieval" :
239+ self .retrieval = instance
240+ if plugin_name == "cache" :
241+ self .cache = instance
242+ if plugin_name == "intent_detection" :
243+ self .intent_detection = instance
244+ if plugin_name == "safety_checker" :
245+ self .safety_checker = instance
250246
251247
252248# A global registry for all model adapters
@@ -266,4 +262,4 @@ def get_model_adapter(model_name_path: str) -> BaseModel:
266262 if adapter .match (model_path_basename ) and type (adapter ) != BaseModel :
267263 return adapter
268264
269- raise ValueError (f"No valid model adapter for { model_name_path } " )
265+ raise ValueError (f"No valid model adapter for { model_name_path } " )
0 commit comments