Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit e95fc32

Browse files
lvliang-intel, hshen14, lkk12014402, mengniwang95
authored
Update NeuralChat plugins API (#156)
* Update NeuralChat plugins API
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
* update safety checker and intent detection plugin
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
* add finetuning lm_eval metric (#147)
* add lm_eval metric.
* support summarization task evaluation.
* update doc.
Signed-off-by: Lv, Kaokao <kaokao.lv@intel.com>
* update doc.
Signed-off-by: Lv, Kaokao <kaokao.lv@intel.com>
* Update README.md
* Update README.md
* add summarization function.
* add evaluation desc.
Signed-off-by: Lv, Kaokao <kaokao.lv@intel.com>
---------
Signed-off-by: Lv, Kaokao <kaokao.lv@intel.com>
Co-authored-by: Haihao Shen <haihao.shen@intel.com>
* remove hard print (#158)
Signed-off-by: Lv, Kaokao <kaokao.lv@intel.com>
* Support PEFT model (#153)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
Co-authored-by: Haihao Shen <haihao.shen@intel.com>
* update code for comments
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
---------
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
Signed-off-by: Lv, Kaokao <kaokao.lv@intel.com>
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
Co-authored-by: Haihao Shen <haihao.shen@intel.com>
Co-authored-by: lkk <33276950+lkk12014402@users.noreply.github.com>
Co-authored-by: Wang, Mengni <mengni.wang@intel.com>
1 parent 1380d5e commit e95fc32

File tree

17 files changed

+403
-201
lines changed

17 files changed

+403
-201
lines changed

neural_chat/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@
2424
from .chatbot import optimize_model
2525
from .server.neuralchat_server import NeuralChatServerExecutor
2626
from .server.neuralchat_client import TextChatClientExecutor, VoiceChatClientExecutor, FinetuingClientExecutor
27+
from .plugins import plugins
2728

neural_chat/chatbot.py

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,22 @@
2020
from .config import PipelineConfig
2121
from .config import OptimizationConfig
2222
from .config import FinetuningConfig
23+
from .plugins import is_plugin_enabled, get_plugin_instance, get_registered_plugins
2324
from .pipeline.finetuning.finetuning import Finetuning
2425
from .pipeline.optimization.optimization import Optimization
25-
from .config import DeviceOptions, AudioLanguageOptions, RetrievalTypeOptions
26+
from .config import DeviceOptions
2627
from .models.base_model import get_model_adapter
2728
from .utils.common import get_device_type
2829
from .pipeline.plugins.caching.cache import init_similar_cache_from_config
2930
from .pipeline.plugins.audio.asr import AudioSpeechRecognition
3031
from .pipeline.plugins.audio.asr_chinese import ChineseAudioSpeechRecognition
3132
from .pipeline.plugins.audio.tts import TextToSpeech
32-
from .pipeline.plugins.audio.tts_chinese_tts import ChineseTextToSpeech
33+
from .pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
3334
from .pipeline.plugins.retrievers.indexing.document_parser import DocumentIndexing
3435
from .pipeline.plugins.retrievers.retriever.langchain import ChromaRetriever
3536
from .pipeline.plugins.retrievers.retriever import BM25Retriever
36-
from .pipeline.plugins.security.sensitive_checker import SensitiveChecker
37+
from .pipeline.plugins.security.safety_checker import SafetyChecker
38+
from .pipeline.plugins.intent_detector import IntentDetector
3739
from .models.llama_model import LlamaModel
3840
from .models.mpt_model import MptModel
3941
from .models.chatglm_model import ChatGlmModel
@@ -66,57 +68,12 @@ def build_chatbot(config: PipelineConfig=None):
6668
# get model adapter
6769
adapter = get_model_adapter(config.model_name_or_path)
6870

69-
# construct document retrieval using retrieval plugin
70-
if config.retrieval:
71-
if config.retrieval_type not in [option.name.lower() for option in RetrievalTypeOptions]:
72-
valid_options = ", ".join([option.name.lower() for option in RetrievalTypeOptions])
73-
raise ValueError(f"Invalid retrieval type value '{config.retrieval_type}'. Must be one of {valid_options}")
74-
if not config.retrieval_document_path:
75-
raise ValueError("Must provide a retrieval document path")
76-
if not os.path.exists(config.retrieval_document_path):
77-
raise ValueError(f"The retrieval document path {config.retrieval_document_path} is not exist.")
78-
db = DocumentIndexing(config.retrieval_type).KB_construct(config.retrieval_document_path)
79-
if config.retrieval_type == "dense":
80-
retriever = ChromaRetriever(db).retriever
81-
else:
82-
retriever = BM25Retriever(document_store = db)
83-
adapter.register_retriever(retriever, config.retrieval_type)
84-
85-
# construct audio plugin
86-
if config.audio_input or config.audio_output:
87-
if config.audio_lang not in [option.name.lower() for option in AudioLanguageOptions]:
88-
valid_options = ", ".join([option.name.lower() for option in AudioLanguageOptions])
89-
raise ValueError(f"Invalid audio language value '{config.audio_lang}'. Must be one of {valid_options}")
90-
if config.audio_input:
91-
if config.audio_lang == AudioLanguageOptions.CHINESE.name.lower():
92-
asr = ChineseAudioSpeechRecognition()
93-
else:
94-
asr = AudioSpeechRecognition()
95-
adapter.register_asr(asr)
96-
if config.audio_output:
97-
if config.audio_lang == AudioLanguageOptions.CHINESE.name.lower():
98-
tts = ChineseTextToSpeech()
99-
else:
100-
tts = TextToSpeech()
101-
adapter.register_tts(tts)
102-
103-
# construct response caching
104-
if config.cache_chat:
105-
if not config.cache_chat_config_file:
106-
cache_chat_config_file = "./pipeline/plugins/caching/cache_config.yaml"
107-
else:
108-
cache_chat_config_file = config.cache_chat_config_file
109-
if not config.cache_embedding_model_dir:
110-
cache_embedding_model_dir = "hkunlp/instructor-large"
111-
else:
112-
cache_embedding_model_dir = config.cache_embedding_model_dir
113-
init_similar_cache_from_config(config_dir=cache_chat_config_file,
114-
embedding_model_dir=cache_embedding_model_dir)
115-
116-
# construct safety checker
117-
if config.safety_checker:
118-
safety_checker = SensitiveChecker()
119-
adapter.register_safety_checker(safety_checker)
71+
# register plugin instance in model adaptor
72+
for plugin_name in get_registered_plugins():
73+
if is_plugin_enabled(plugin_name):
74+
plugin_instance = get_plugin_instance(plugin_name)
75+
if plugin_instance:
76+
adapter.register_plugin_instance(plugin_name, plugin_instance)
12077

12178
parameters = {}
12279
parameters["model_name"] = config.model_name_or_path
@@ -133,6 +90,7 @@ def build_chatbot(config: PipelineConfig=None):
13390
parameters["dtype"] = config.optimization_config.amp_config.dtype
13491
parameters["optimization_config"] = config.optimization_config
13592
adapter.load_model(parameters)
93+
13694
return adapter
13795

13896
def finetune_model(config: FinetuningConfig):

neural_chat/cli/cli_commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from neural_chat.pipeline.plugins.audio.asr import AudioSpeechRecognition
2626
from neural_chat.pipeline.plugins.audio.asr_chinese import ChineseAudioSpeechRecognition
2727
from neural_chat.pipeline.plugins.audio.tts import TextToSpeech
28-
from neural_chat.pipeline.plugins.audio.tts_chinese_tts import ChineseTextToSpeech
28+
from neural_chat.pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
2929

3030
__all__ = ['BaseCommand', 'HelpCommand', 'TextChatExecutor', 'VoiceChatExecutor', 'FinetuingExecutor']
3131

neural_chat/config.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222
from transformers.utils.versions import require_version
2323
from dataclasses import dataclass
2424

25+
from neural_chat.pipeline.plugins.audio.asr import AudioSpeechRecognition
26+
from neural_chat.pipeline.plugins.audio.asr_chinese import ChineseAudioSpeechRecognition
27+
from neural_chat.pipeline.plugins.audio.tts import TextToSpeech
28+
from neural_chat.pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
29+
from .plugins import plugins
30+
2531
from enum import Enum, auto
2632

2733
class DeviceOptions(Enum):
@@ -383,27 +389,25 @@ class IntentConfig:
383389
ipex_int8: bool = False
384390

385391

386-
@dataclass
387392
class PipelineConfig:
388-
model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
389-
tokenizer_name_or_path: str = None
390-
device: str = "auto"
391-
retrieval: bool = False
392-
retrieval_type: str = "dense"
393-
retrieval_document_path: str = None
394-
retrieval_config: RetrieverConfig = RetrieverConfig()
395-
audio_input: bool = False
396-
audio_output: bool = False
397-
audio_lang: str = "english"
398-
txt2Image: bool = False
399-
cache_chat: bool = False
400-
cache_chat_config_file: str = None
401-
cache_embedding_model_dir: str = None
402-
intent_detection: bool = False
403-
intent_config: IntentConfig = IntentConfig()
404-
memory_controller: bool = False
405-
safety_checker: bool = False
406-
saftey_config: SafetyConfig = SafetyConfig()
407-
loading_config: LoadingModelConfig = LoadingModelConfig()
408-
optimization_config: OptimizationConfig = OptimizationConfig()
393+
def __init__(self,
394+
model_name_or_path="meta-llama/Llama-2-7b-hf",
395+
tokenizer_name_or_path=None,
396+
device="auto",
397+
plugins=plugins,
398+
loading_config=None,
399+
optimization_config=None):
400+
self.model_name_or_path = model_name_or_path
401+
self.tokenizer_name_or_path = tokenizer_name_or_path
402+
self.device = device
403+
self.plugins = plugins
404+
self.loading_config = loading_config if loading_config is not None else LoadingModelConfig()
405+
self.optimization_config = optimization_config if optimization_config is not None else OptimizationConfig()
406+
for plugin_name, plugin_value in self.plugins.items():
407+
if plugin_value['enable']:
408+
print(f"create {plugin_name} plugin instance...")
409+
print(f"plugin parameters: ", plugin_value['args'])
410+
plugins[plugin_name]["instance"] = plugin_value['class'](**plugin_value['args'])
411+
412+
409413

neural_chat/models/base_model.py

Lines changed: 58 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18-
from abc import ABC, abstractmethod
18+
from abc import ABC
1919
from typing import List
2020
import os
2121
from fastchat.conversation import get_conv_template, Conversation
22-
from neural_chat.pipeline.inference.inference import load_model, predict, predict_stream
22+
from neural_chat.pipeline.inference.inference import load_model, predict, predict_stream, MODELS
2323
from neural_chat.config import GenerationConfig
24+
from neural_chat.plugins import is_plugin_enabled, get_plugin_instance, get_registered_plugins, get_plugin_arguments
2425
from neural_chat.utils.common import is_audio_file
2526
from neural_chat.pipeline.plugins.prompts.prompt import generate_qa_prompt, generate_prompt
2627

@@ -59,7 +60,7 @@ def construct_prompt(query, retriever, retrieval_type):
5960
return generate_qa_prompt(query, context)
6061
else:
6162
return generate_prompt(query)
62-
63+
6364

6465
class BaseModel(ABC):
6566
"""
@@ -79,6 +80,7 @@ def __init__(self):
7980
self.retrieval_type = None
8081
self.safety_checker = None
8182
self.intent_detection = False
83+
self.cache = None
8284

8385
def match(self, model_path: str):
8486
"""
@@ -148,34 +150,42 @@ def predict(self, query, config=None):
148150
if is_audio_file(query):
149151
if not os.path.exists(query):
150152
raise ValueError(f"The audio file path {query} is invalid.")
151-
if self.asr:
152-
query = self.asr.audio2text(query)
153-
else:
154-
raise ValueError(f"The query {query} is audio file but there is no ASR registered.")
153+
154+
# plugin pre actions
155+
for plugin_name in get_registered_plugins():
156+
if is_plugin_enabled(plugin_name):
157+
plugin_instance = get_plugin_instance(plugin_name)
158+
if plugin_instance:
159+
if hasattr(plugin_instance, 'pre_llm_inference_actions'):
160+
if plugin_name == "asr" and not is_audio_file(query):
161+
continue
162+
if plugin_name == "intent_detection":
163+
response = plugin_instance.pre_llm_inference_actions(query,
164+
MODELS[self.model_name]["model"], MODELS[self.model_name]["tokenizer"])
165+
else:
166+
response = plugin_instance.pre_llm_inference_actions(query)
167+
if plugin_name == "safety_checker" and response:
168+
return "Your query contains sensitive words, please try another query."
169+
elif plugin_name == "intent_detection":
170+
if 'qa' not in response.lower():
171+
query = generate_prompt(query)
172+
else:
173+
query = generate_qa_prompt(query)
174+
else:
175+
query = response
155176
assert query is not None, "Query cannot be None."
156177

157-
if self.intent_detection:
158-
intent = predict(**construct_parameters(query, self.model_name, config.intent_config))
159-
if 'qa' not in intent.lower():
160-
intent = "chitchat"
161-
query = generate_prompt(query)
162-
elif self.retriever:
163-
query = construct_prompt(query, self.retriever, self.retrieval_type)
164-
else:
165-
query = generate_qa_prompt(query)
166-
else:
167-
if self.retriever:
168-
query = construct_prompt(query, self.retriever, self.retrieval_type)
169-
170-
if self.safety_checker:
171-
assert self.safety_checker.sensitive_check(query) is False, "The input query contains sensitive words."
178+
# LLM inference
172179
response = predict(**construct_parameters(query, self.model_name, config))
173-
if self.safety_checker:
174-
if self.safety_checker.sensitive_check(response):
175-
response = self.safety_checker.sensitive_filter(response)
176-
if self.tts:
177-
self.tts.text2speech(response, config.audio_output_path)
178-
response = config.audio_output_path
180+
181+
# plugin post actions
182+
for plugin_name in get_registered_plugins():
183+
if is_plugin_enabled(plugin_name):
184+
plugin_instance = get_plugin_instance(plugin_name)
185+
if plugin_instance:
186+
if hasattr(plugin_instance, 'post_llm_inference_actions'):
187+
response = plugin_instance.post_llm_inference_actions(response)
188+
179189
return response
180190

181191
def chat_stream(self, query, config=None):
@@ -210,43 +220,29 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
210220
"""
211221
return get_conv_template("one_shot")
212222

213-
def register_tts(self, instance):
214-
"""
215-
Register a text-to-speech (TTS) instance.
216-
217-
Args:
218-
instance: An instance of a TTS module.
219-
"""
220-
self.tts = instance
221-
222-
def register_asr(self, instance):
223-
"""
224-
Register an automatic speech recognition (ASR) instance.
225-
226-
Args:
227-
instance: An instance of an ASR module.
228-
"""
229-
self.asr = instance
230-
231-
def register_safety_checker(self, instance):
232-
"""
233-
Register a safety checker instance.
234-
235-
Args:
236-
instance: An instance of a safety checker module.
237-
"""
238-
self.safety_checker = instance
239-
240-
def register_retriever(self, retriever, retrieval_type):
223+
def register_plugin_instance(self, plugin_name, instance):
241224
"""
242-
Register a database retriever.
225+
Register a plugin instance.
243226
244227
Args:
245-
instance: An instance of a retriever.
246-
retrieval_type: The type of the retrieval method.
228+
instance: An instance of a plugin.
247229
"""
248-
self.retriever = retriever
249-
self.retrieval_type = retrieval_type
230+
if plugin_name == "tts":
231+
self.tts = instance
232+
if plugin_name == "tts_chinese":
233+
self.tts_chinese = instance
234+
if plugin_name == "asr":
235+
self.asr = instance
236+
if plugin_name == "asr_chinese":
237+
self.asr_chinese = instance
238+
if plugin_name == "retrieval":
239+
self.retrieval = instance
240+
if plugin_name == "cache":
241+
self.cache = instance
242+
if plugin_name == "intent_detection":
243+
self.intent_detection = instance
244+
if plugin_name == "safety_checker":
245+
self.safety_checker = instance
250246

251247

252248
# A global registry for all model adapters
@@ -266,4 +262,4 @@ def get_model_adapter(model_name_path: str) -> BaseModel:
266262
if adapter.match(model_path_basename) and type(adapter) != BaseModel:
267263
return adapter
268264

269-
raise ValueError(f"No valid model adapter for {model_name_path}")
265+
raise ValueError(f"No valid model adapter for {model_name_path}")

neural_chat/pipeline/plugins/audio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .asr import AudioSpeechRecognition
22
from .asr_chinese import ChineseAudioSpeechRecognition
33
from .tts import TextToSpeech
4-
from .tts_chinese_tts import ChineseTextToSpeech
4+
from .tts_chinese import ChineseTextToSpeech

neural_chat/pipeline/plugins/audio/asr.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import contextlib
2323
from pydub import AudioSegment
2424

25+
from neural_chat.plugins import register_plugin
2526

26-
class AudioSpeechRecognition:
27+
@register_plugin('asr')
28+
class AudioSpeechRecognition():
2729
"""Convert audio to text."""
2830
def __init__(self, model_name_or_path="openai/whisper-small", bf16=False, device="cpu"):
2931
self.device = device
@@ -58,4 +60,8 @@ def audio2text(self, audio_path):
5860
predicted_ids = self.model.generate(inputs)
5961
result = self.processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
6062
print(f"generated text in {time.time() - start} seconds, and the result is: {result}")
61-
return result
63+
return result
64+
65+
66+
def pre_llm_inference_actions(self, audio_path):
67+
return self.audio2text(audio_path)

neural_chat/pipeline/plugins/audio/asr_chinese.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
from paddlespeech.cli.asr.infer import ASRExecutor
1919
import time
20+
from neural_chat.plugins import register_plugin
2021

21-
class ChineseAudioSpeechRecognition:
22+
@register_plugin('asr_chinese')
23+
class ChineseAudioSpeechRecognition():
2224
"""Convert audio to text in Chinese."""
2325
def __init__(self):
2426
self.asr = ASRExecutor()
@@ -31,3 +33,6 @@ def audio2text(self, audio_path):
3133
start = time.time()
3234
result = self.asr(audio_file=audio_path)
3335
print(f"generated text in {time.time() - start} seconds, and the result is: {result}")
36+
37+
def pre_llm_inference_actions(self, audio_path):
38+
return self.audio2text(audio_path)

0 commit comments

Comments
 (0)