Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit d2cee03

Browse files
Fix the bug in rag example (Neural Chat) (#226)
* code revision (Signed-off-by: XuhuiRen <xuhui.ren@intel.com>)
* revision (Signed-off-by: XuhuiRen <xuhui.ren@intel.com>)
* revision (Signed-off-by: XuhuiRen <xuhui.ren@intel.com>)
* revision (Signed-off-by: XuhuiRen <xuhui.ren@intel.com>)
* revise (Signed-off-by: XuhuiRen <xuhui.ren@intel.com>)
* Update __init__.py
* Update document_parser.py
* fixed import error. (Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>)
---------
Signed-off-by: XuhuiRen <xuhui.ren@intel.com>
Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>
Co-authored-by: Ye, Xinyu <xinyu.ye@intel.com>
1 parent 81a651d commit d2cee03

File tree

7 files changed

+23
-29
lines changed

7 files changed

+23
-29
lines changed

intel_extension_for_transformers/neural_chat/chatbot.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@
3030
from .pipeline.plugins.audio.asr_chinese import ChineseAudioSpeechRecognition
3131
from .pipeline.plugins.audio.tts import TextToSpeech
3232
from .pipeline.plugins.audio.tts_chinese import ChineseTextToSpeech
33+
from .pipeline.plugins.security import SafetyChecker
3334
from .pipeline.plugins.retrievals import QA_Client
34-
from .pipeline.plugins.security.safety_checker import SafetyChecker
35-
from .pipeline.plugins.intent_detector import IntentDetector
3635
from .models.llama_model import LlamaModel
3736
from .models.mpt_model import MptModel
3837
from .models.chatglm_model import ChatGlmModel

intel_extension_for_transformers/neural_chat/cli/cli_commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import List
2121
from ..utils.command import NeuralChatCommandDict
2222
from .base_executor import BaseCommandExecutor
23-
from ..config import PipelineConfig, FinetuningConfig, GenerationConfig # pylint: disable=E0611
23+
from ..config import PipelineConfig, TextGenerationFinetuningConfig, GenerationConfig
2424
from ..config import ModelArguments, DataArguments, FinetuningArguments
2525
from ..plugins import plugins
2626
from transformers import TrainingArguments
@@ -311,7 +311,7 @@ def execute(self, argv: List[str]) -> bool:
311311
training_args = TrainingArguments(output_dir="./output")
312312
finetune_args= FinetuningArguments()
313313

314-
self.finetuneCfg = FinetuningConfig(model_args, data_args, training_args, finetune_args)
314+
self.finetuneCfg = TextGenerationFinetuningConfig(model_args, data_args, training_args, finetune_args)
315315
try:
316316
res = self()
317317
print(res)

intel_extension_for_transformers/neural_chat/examples/retrieval/retrieval.py renamed to intel_extension_for_transformers/neural_chat/examples/retrieval/retrieval_chat.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,23 @@
1818
import os
1919
import sys
2020
from transformers import TrainingArguments, HfArgumentParser
21-
from intel_extension_for_transformers.neural_chat.config import (
22-
PipelineConfig,
23-
RetrieverConfig,
24-
SafetyConfig,
25-
GenerationConfig
26-
)
21+
from intel_extension_for_transformers.neural_chat.config import PipelineConfig
2722
from intel_extension_for_transformers.neural_chat.chatbot import build_chatbot
2823

2924

3025
def main():
3126
# See all possible arguments in config.py
3227
# or by passing the --help flag to this script.
3328
# We now keep distinct sets of args, for a cleaner separation of concerns.
34-
parser = HfArgumentParser(
35-
(PipelineConfig, RetrieverConfig, SafetyConfig, GenerationConfig)
36-
)
29+
parser = HfArgumentParser(PipelineConfig)
3730

3831
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
3932
# If we pass only one argument to the script and it's the path to a json file,
4033
# let's parse it to get our arguments.
41-
pipeline_args, retriever_args, safety_args, generation_args = parser.parse_json_file(
42-
json_file = os.path.abspath(sys.argv[1])
43-
)
34+
pipeline_args= parser.parse_json_file(json_file = os.path.abspath(sys.argv[1]))
4435
else:
45-
(pipeline_args, retriever_args, safety_args, generation_args) = parser.parse_args_into_dataclasses()
36+
pipeline_args= parser.parse_args_into_dataclasses()
4637

47-
pipeline_args.saftey_config = safety_args
48-
pipeline_args.retrieval_config = retriever_args
49-
pipeline_args.generation_config = generation_args
5038
chatbot = build_chatbot(pipeline_args)
5139

5240
response = chatbot.predict(query="What is IDM 2.0?", config=pipeline_args)

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrievals/indexing/document_parser.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,17 @@ def KB_construct(self, input):
9797

9898
documents = []
9999
for data, meta in data_collection:
100+
if len(data) < 5:
101+
continue
100102
metadata = {"source": meta}
101103
new_doc = Document(page_content=data, metadata=metadata)
102104
documents.append(new_doc)
105+
assert documents!= [], "The given file/files cannot be loaded."
103106
embedding = HuggingFaceInstructEmbeddings(model_name=self.embedding_model)
104107
vectordb = Chroma.from_documents(documents=documents, embedding=embedding,
105108
persist_directory=self.persist_dir)
106109
vectordb.persist()
107-
print("success")
110+
print("The local knowledge base has been successfully built!")
108111
return vectordb
109112
else:
110113
print("There might be some errors, please wait and try again!")
@@ -125,11 +128,13 @@ def KB_construct(self, input):
125128
documents = []
126129
for data, meta in data_collection:
127130
metadata = {"source": meta}
128-
# pylint: disable=E1123
129-
new_doc = SDocument(content=data, metadata=metadata)
131+
if len(data) < 5:
132+
continue
133+
new_doc = SDocument(content=data, meta=metadata)
130134
documents.append(new_doc)
135+
assert documents != [], "The given file/files cannot be loaded."
131136
document_store.write_documents(documents)
132-
print("success")
137+
print("The local knowledge base has been successfully built!")
133138
return document_store
134139
else:
135140
print("There might be some errors, please wait and try again!")

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrievals/indexing/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,19 +165,21 @@ def laod_structured_data(input, process, max_length):
165165

166166
def get_chuck_data(content, max_length, input):
167167
"""Process the context to make it maintain a suitable length for the generation."""
168-
sentences = re.split('(?<=[;!.?])', content)
168+
sentences = re.split('(?<=[!.?])', content)
169169

170170
paragraphs = []
171171
current_length = 0
172172
count = 0
173173
current_paragraph = ""
174174
for sub_sen in sentences:
175+
if sub_sen == "":
176+
continue
175177
count +=1
176178
sentence_length = len(sub_sen)
177179
if current_length + sentence_length <= max_length:
178180
current_paragraph += sub_sen
179181
current_length += sentence_length
180-
if count == len(sentences):
182+
if count == len(sentences) and len(current_paragraph.strip())>5:
181183
paragraphs.append([current_paragraph.strip() ,input])
182184
else:
183185
paragraphs.append([current_paragraph.strip() ,input])

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrievals/retrieval/bm25_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class SparseBM25Retriever():
2222

2323
def __int__(self, document_store = None, top_k = 1):
2424
assert document_store is not None, "Please give a document database for retrieving."
25-
self.retriever = BM25Retriever(document_store=document_store)
25+
self.retriever = BM25Retriever(document_store=document_store, top_k=top_k)
2626

2727
def query_the_database(self, query):
2828
documents = self.retriever.retrieve(query)

intel_extension_for_transformers/neural_chat/server/restful/finetune_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ModelArguments,
2727
DataArguments,
2828
FinetuningArguments,
29-
FinetuningConfig,
29+
TextGenerationFinetuningConfig,
3030
)
3131
from intel_extension_for_transformers.neural_chat.server.restful.request import FinetuneRequest
3232

@@ -66,7 +66,7 @@ def handle_finetune_request(self, request: FinetuneRequest) -> str:
6666
overwrite_output_dir=request.overwrite_output_dir
6767
)
6868
finetune_args = FinetuningArguments(peft=request.peft)
69-
finetune_cfg = FinetuningConfig(
69+
finetune_cfg = TextGenerationFinetuningConfig(
7070
model_args=model_args,
7171
data_args=data_args,
7272
training_args=training_args,

0 commit comments

Comments (0)