
Commit daff796

Add neuralchat server unit test cases (#276)
1 parent: 0991c17

File tree: 17 files changed, +440 −106 lines

.github/workflows/script/unitTest/run_unit_test_neuralchat.sh

Lines changed: 21 additions & 0 deletions
```diff
@@ -28,6 +28,24 @@ function pytest() {
 
     itrex_path=$(python -c 'import intel_extension_for_transformers; import os; print(os.path.dirname(intel_extension_for_transformers.__file__))')
     find . -name "test*.py" | sed 's,\.\/,coverage run --source='"${itrex_path}"' --append ,g' | sed 's/$/ --verbose/' >run.sh
+    echo -e '
+    # Kill the neuralchat server processes
+    ports="7000 8000 9000"
+    # Loop through each port and find associated PIDs
+    for port in $ports; do
+        # Use lsof to find the processes associated with the port
+        pids=$(lsof -ti :$port)
+
+        if [ -n "$pids" ]; then
+            echo "Processes running on port $port: $pids"
+            # Terminate the processes gracefully with SIGTERM
+            kill $pids
+            echo "Terminated processes on port $port."
+        else
+            echo "No processes found on port $port."
+        fi
+    done
+    ' >> run.sh
     coverage erase
 
     # run UT
@@ -52,6 +70,9 @@ function pytest() {
 
 function main() {
     bash /intel-extension-for-transformers/.github/workflows/script/unitTest/env_setup.sh
+    apt-get update
+    apt-get install ffmpeg -y
+    apt-get install lsof
     wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
     dpkg -i libssl1.1_1.1.1f-1ubuntu2.19_amd64.deb
     python -m pip install --upgrade --force-reinstall torch
```
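The appended block frees the fixed NeuralChat test ports (7000, 8000, 9000) between runs so leftover servers cannot block the next test. For illustration only, here is a hedged Python sketch of the same cleanup logic (not part of the commit; `NEURALCHAT_PORTS` is an illustrative name). It assumes `lsof` is on PATH, which the `apt-get install lsof` line above provides.

```python
# Sketch (not part of the commit): the same port cleanup in Python.
import os
import signal
import subprocess

NEURALCHAT_PORTS = ("7000", "8000", "9000")  # illustrative constant

for port in NEURALCHAT_PORTS:
    # `lsof -ti :PORT` prints one PID per line for listeners on that port
    result = subprocess.run(["lsof", "-ti", f":{port}"],
                            capture_output=True, text=True)
    pids = result.stdout.split()
    if pids:
        print(f"Processes running on port {port}: {' '.join(pids)}")
        for pid in pids:
            os.kill(int(pid), signal.SIGTERM)  # graceful termination
        print(f"Terminated processes on port {port}.")
    else:
        print(f"No processes found on port {port}.")
```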

.github/workflows/unit-test-neuralchat.yml

Lines changed: 22 additions & 22 deletions
```diff
@@ -24,27 +24,27 @@ env:
 
 jobs:
   unit-test:
-    runs-on: [self-hosted, Linux, X64, itrex-node]
+    runs-on: [self-hosted, Linux, X64, neuralchat-node]
     strategy:
       matrix:
         include:
           - test_branch: ${{ github.ref }}
             test_name: "PR-test"
-          - test_branch: "main"
-            test_name: "baseline"
+          # - test_branch: "main"
+          #   test_name: "baseline"
     steps:
-      - name: Docker Clean Up
+      - name: podman Clean Up
        run: |
-          docker ps -a
-          if [[ $(docker ps -a | grep -i '${{ env.CONTAINER_NAME }}'$) ]]; then
-            docker start ${{ env.CONTAINER_NAME }}
+          podman ps -a
+          if [[ $(podman ps -a | grep -i '${{ env.CONTAINER_NAME }}'$) ]]; then
+            podman start ${{ env.CONTAINER_NAME }}
             echo "remove left files through container ..."
-            docker exec ${{ env.CONTAINER_NAME }} bash -c "ls -a /intel-extension-for-transformers && rm -fr /intel-extension-for-transformers/* && rm -fr /intel-extension-for-transformers/.* || true"
+            podman exec ${{ env.CONTAINER_NAME }} bash -c "ls -a /intel-extension-for-transformers && rm -fr /intel-extension-for-transformers/* && rm -fr /intel-extension-for-transformers/.* || true"
           fi
-          if [[ $(docker ps -a | grep -i '${{ env.EXTRA_CONTAINER_NAME }}'$) ]]; then
-            docker start ${{ env.EXTRA_CONTAINER_NAME }}
+          if [[ $(podman ps -a | grep -i '${{ env.EXTRA_CONTAINER_NAME }}'$) ]]; then
+            podman start ${{ env.EXTRA_CONTAINER_NAME }}
             echo "remove left files through container ..."
-            docker exec ${{ env.EXTRA_CONTAINER_NAME }} bash -c "ls -a /intel-extension-for-transformers && rm -fr /intel-extension-for-transformers/* && rm -fr /intel-extension-for-transformers/.* || true"
+            podman exec ${{ env.EXTRA_CONTAINER_NAME }} bash -c "ls -a /intel-extension-for-transformers && rm -fr /intel-extension-for-transformers/* && rm -fr /intel-extension-for-transformers/.* || true"
           fi
 
       - name: Checkout out Repo
@@ -54,28 +54,28 @@ jobs:
           ref: ${{ matrix.test_branch }}
           fetch-tags: true
 
-      - name: Docker Build
+      - name: podman Build
         run: |
-          docker build -f ${{ github.workspace }}/.github/workflows/docker/${{ env.DOCKER_FILE_NAME }}.dockerfile -t ${{ env.REPO_NAME }}:${{ env.REPO_TAG }} .
+          podman build -f ${{ github.workspace }}/.github/workflows/docker/${{ env.DOCKER_FILE_NAME }}.dockerfile -t ${{ env.REPO_NAME }}:${{ env.REPO_TAG }} .
 
-      - name: Docker Run
+      - name: podman Run
         run: |
-          if [[ $(docker ps -a | grep -i '${{ env.CONTAINER_NAME }}'$) ]]; then
-            docker stop ${{ env.CONTAINER_NAME }}
-            docker rm -vf ${{ env.CONTAINER_NAME }} || true
+          if [[ $(podman ps -a | grep -i '${{ env.CONTAINER_NAME }}'$) ]]; then
+            podman stop ${{ env.CONTAINER_NAME }}
+            podman rm -vf ${{ env.CONTAINER_NAME }} || true
           fi
-          docker run -dit --disable-content-trust --privileged --name=${{ env.CONTAINER_NAME }} -v /dev/shm:/dev/shm \
+          podman run -dit --disable-content-trust --privileged --name=${{ env.CONTAINER_NAME }} -v /dev/shm:/dev/shm \
             -v ${{ github.workspace }}:/intel-extension-for-transformers \
             ${{ env.REPO_NAME }}:${{ env.REPO_TAG }}
 
       - name: Env build
         run: |
-          docker exec ${{ env.CONTAINER_NAME }} \
+          podman exec ${{ env.CONTAINER_NAME }} \
             bash /intel-extension-for-transformers/.github/workflows/script/prepare_env.sh
 
       - name: Binary build
         run: |
-          docker exec ${{ env.CONTAINER_NAME }} \
+          podman exec ${{ env.CONTAINER_NAME }} \
             bash -c "cd /intel-extension-for-transformers/.github/workflows/script \
             && bash install_binary.sh \
             && pip install intel_extension_for_pytorch wget sentencepiece \
@@ -91,14 +91,14 @@ jobs:
 
       - name: Run UT
         run: |
-          docker exec ${{ env.CONTAINER_NAME }} \
+          podman exec ${{ env.CONTAINER_NAME }} \
             bash -c "cd /intel-extension-for-transformers/.github/workflows/script/unitTest \
             && bash run_unit_test_neuralchat.sh --test_name=${{ matrix.test_name }}"
 
       - name: Collect log
         if: ${{ !cancelled() }}
         run: |
-          docker exec ${{ env.CONTAINER_NAME }} \
+          podman exec ${{ env.CONTAINER_NAME }} \
             bash -c "cd /intel-extension-for-transformers && \
             mv /log_dir . "
```

intel_extension_for_transformers/neural_chat/config.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -21,6 +21,7 @@
 from transformers import TrainingArguments, BitsAndBytesConfig
 from transformers.utils.versions import require_version
 from dataclasses import dataclass
+from .utils.common import get_device_type
 
 from .plugins import plugins
 
@@ -429,9 +430,15 @@ def __init__(self,
         self.model_name_or_path = model_name_or_path
         self.tokenizer_name_or_path = tokenizer_name_or_path
         self.hf_access_token = hf_access_token
-        self.device = device
+        if device == "auto":
+            self.device = get_device_type()
+        else:
+            self.device = device
+
         self.plugins = plugins
-        self.loading_config = loading_config if loading_config is not None else LoadingModelConfig()
+        self.loading_config = loading_config if loading_config is not None else \
+            LoadingModelConfig(cpu_jit=True if self.device == "cpu" else False, \
+                               use_hpu_graphs=True if self.device == "hpu" else False)
         self.optimization_config = optimization_config if optimization_config is not None else AMPConfig()
         assert type(self.optimization_config) in [AMPConfig, WeightOnlyQuantizationConfig, BitsAndBytesConfig], \
             f"Expect optimization_config be an object of AMPConfig, WeightOnlyQuantizationConfig" + \
```

intel_extension_for_transformers/neural_chat/pipeline/plugins/audio/tts.py

Lines changed: 16 additions & 16 deletions
```diff
@@ -33,7 +33,7 @@ class TextToSpeech():
     """Convert text to speech with a driven speaker embedding
 
     1) Default voice (Original model + Proved good default speaker embedding from trained dataset)
-    2) Finetuned voice (Fine-tuned offline model of specific person, such as Pat's voice + corresponding embedding)
+    2) Finetuned voice (Fine-tuned offline model of specific person's voice + corresponding embedding)
     3) Customized voice (Original model + User's customized input voice embedding)
     """
     def __init__(self, output_audio_path="./response.wav", voice="default", stream_mode=False, device="cpu",
@@ -66,16 +66,16 @@ def __init__(self, output_audio_path="./response.wav", voice="default", stream_m
         self.default_speaker_embedding = torch.load(default_speaker_embedding_path)
 
         # preload the demo model in case of time-consuming runtime loading
-        self.pat_model = None
-        if os.path.exists("pat.pt"):
-            self.pat_model = torch.load("pat.pt", map_location=device)
+        self.demo_model = None
+        if os.path.exists("demo_model.pt"):
+            self.demo_model = torch.load("demo_model.pt", map_location=device)
 
-        self.pat_speaker_embeddings = None
-        pat_speaker_embedding_path = os.path.join(script_dir, '../../../assets/speaker_embeddings/spk_embed_pat.pt')
+        self.male_speaker_embeddings = None
+        pat_speaker_embedding_path = os.path.join(script_dir, '../../../assets/speaker_embeddings/spk_embed_male.pt')
         if os.path.exists(pat_speaker_embedding_path):
-            self.pat_speaker_embeddings = torch.load(pat_speaker_embedding_path)
-        elif os.path.exists(os.path.join(asset_path, 'speaker_embeddings/spk_embed_pat.pt')):
-            self.pat_speaker_embeddings = torch.load(os.path.join(asset_path, 'speaker_embeddings/spk_embed_pat.pt'))
+            self.male_speaker_embeddings = torch.load(pat_speaker_embedding_path)
+        elif os.path.exists(os.path.join(asset_path, 'speaker_embeddings/spk_embed_male.pt')):
+            self.male_speaker_embeddings = torch.load(os.path.join(asset_path, 'speaker_embeddings/spk_embed_male.pt'))
 
         self.cpu_pool = None
         if not torch.cuda.is_available():
@@ -148,7 +148,7 @@ def text2speech(self, text, output_audio_path, voice="default", do_batch_tts=Fal
         """Text to speech.
 
         text: the input text
-        voice: default/pat/huma/tom/eric...
+        voice: default/male/female/...
         batch_length: the batch length for spliting long texts into batches to do text to speech
         """
         print(text)
@@ -164,15 +164,15 @@ def text2speech(self, text, output_audio_path, voice="default", do_batch_tts=Fal
         print(texts)
         model = self.original_model
         speaker_embeddings = self.default_speaker_embedding
-        if voice == "pat":
-            if self.pat_model == None:
+        if voice == "male":
+            if self.demo_model == None:
                 print("Finetuned model is not found! Use the default one")
             else:
-                model = self.pat_model
-            if self.pat_speaker_embeddings == None:
-                print("Pat's speaker embedding is not found! Use the default one")
+                model = self.demo_model
+            if self.male_speaker_embeddings == None:
+                print("Male speaker embedding is not found! Use the default one")
             else:
-                speaker_embeddings = self.pat_speaker_embeddings
+                speaker_embeddings = self.male_speaker_embeddings
         elif voice != "default":
             speaker_embeddings = torch.load(self._lookup_voice_embedding(voice))
         all_speech = np.array([])
```
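A short, hedged usage sketch of the renamed voice path: `voice="male"` now routes to the generic demo model and the `spk_embed_male.pt` embedding, and the code above falls back to the defaults with a printed notice when either file is absent.

```python
from intel_extension_for_transformers.neural_chat.pipeline.plugins.audio.tts \
    import TextToSpeech

# If demo_model.pt or spk_embed_male.pt is missing on disk, text2speech()
# prints a notice and uses the original model / default speaker embedding.
tts = TextToSpeech(device="cpu")
tts.text2speech("Hello from NeuralChat.", "response.wav", voice="male")
```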

intel_extension_for_transformers/neural_chat/server/config/neuralchat.yaml

Lines changed: 47 additions & 14 deletions
```diff
@@ -23,23 +23,56 @@
 host: 0.0.0.0
 port: 8000
 
-# task choices = ['textchat', 'voicechat', 'retrieval', 'text2image', 'finetune']
-tasks_list: ['textchat', 'finetune', 'retrieval'] # text chatbot with document retrieval
+model_name_or_path: "meta-llama/Llama-2-7b-chat-hf"
+device: "auto"
+
+asr:
+    enable: true
+    args:
+        # support cpu, hpu, xpu, cuda
+        device: "cpu"
+        # support openai/whisper series
+        model_name_or_path: "openai/whisper-small"
+        # only can be set to true when the device is set to "cpu"
+        bf16: false
+
+tts:
+    enable: true
+    args:
+        device: "cpu"
+        voice: "default"
+        stream_mode: false
+        output_audio_path: "./output_audio.wav"
 
-# plugins choices = ['audio', 'retrieval', 'caching', 'memory_controller', 'intent_detection', 'safety_checker']
-plugins_list: ['audio', 'retrieval', 'caching']
+asr_chinese:
+    enable: false
 
-audio:
-    audio_input: true
-    audio_output: true
-    language: "english"
+tts_chinese:
+    enable: false
+    args:
+        device: "cpu"
+        spk_id: 0
+        stream_mode: false
+        output_audio_path: "./output_audio.wav"
 
 retrieval:
-    retrieval_type: "dense"
-    retrieval_document_path: "../../assets/docs/"
+    enable: true
+    args:
+        retrieval_type: "dense"
+        input_path: "../../assets/docs/"
+        embedding_model: "hkunlp/instructor-large"
+        persist_dir: "./output"
+        max_length: 512
+        process: true
 
-caching:
-    cache_chat_config_file: "../../plugins/caching/cache_config.yaml"
-    cache_embedding_model_dir: "hkunlp/instructor-large"
+cache:
+    enable: true
+    args:
+        config_dir: "../../pipeline/plugins/caching/cache_config.yaml"
+        embedding_model_dir: "hkunlp/instructor-large"
 
-model_name: "meta-llama/Llama-2-7b-chat-hf"
+safety_checker:
+    enable: true
+
+# task choices = ['textchat', 'voicechat', 'retrieval', 'text2image', 'finetune']
+tasks_list: ['textchat', 'retrieval']
```
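Every plugin section now follows a uniform `enable`/`args` shape, which is what lets the server iterate over plugins generically (see `neuralchat_server.py` below). A hedged sketch of consuming the restructured file, assuming PyYAML is available:

```python
import yaml  # assumes PyYAML is installed in the environment

with open("neuralchat.yaml") as f:
    cfg = yaml.safe_load(f)

# Collect the plugins switched on by the uniform enable/args convention.
enabled = {name: section.get("args", {})
           for name, section in cfg.items()
           if isinstance(section, dict) and section.get("enable")}
print(sorted(enabled))  # e.g. ['asr', 'cache', 'retrieval', 'safety_checker', 'tts']
```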

intel_extension_for_transformers/neural_chat/server/neuralchat_server.py

Lines changed: 22 additions & 28 deletions
```diff
@@ -17,7 +17,6 @@
 
 import argparse
 import sys
-import os
 from typing import List
 
 
@@ -36,7 +35,7 @@
 from .restful.api import setup_router
 from ..config import PipelineConfig
 from ..chatbot import build_chatbot
-
+from ..plugins import plugins
 
 __all__ = ['NeuralChatServerExecutor']
 
@@ -97,31 +96,23 @@ def init(self, config):
         Returns:
             bool:
         """
-        plugin_list = list(plugin for plugin in config.plugins_list)
-        params = {}
-        # Model configuration
-        if config.model_name:
-            params["model_name_or_path"] = config.model_name
-        # Audio plugin configuration
-        if "audio" in plugin_list:
-            params["audio_input"] = config.audio.audio_input
-            params["audio_output"] = config.audio.audio_output
-        # Retrieval plugin configuration
-        if "retrieval" in plugin_list:
-            params["retrieval_type"] = config.retrieval.retrieval_type
-            script_dir = os.path.dirname(os.path.abspath(__file__))
-            retrieval_document_path = os.path.join(script_dir, config.retrieval.retrieval_document_path)
-            params["retrieval_document_path"] = retrieval_document_path
-        # Caching plugin configuration
-        if "caching" in plugin_list:
-            params["cache_chat_config_file"] = config.caching.cache_chat_config_file
-            script_dir = os.path.dirname(os.path.abspath(__file__))
-            retrieval_document_path = os.path.join(script_dir, config.caching.cache_embedding_model_dir)
-            params["cache_embedding_model_dir"] = retrieval_document_path
-        # Other plugins configurations
-        for plugin in ["memory_controller", "intent_detection", "safety_checker"]:
-            if plugin in config.plugins_list:
-                params[plugin] = True
+        device = config.get("device", "auto")
+        model_name_or_path = config.get("model_name_or_path", "meta-llama/Llama-2-7b-hf")
+
+        # Update plugins based on YAML configuration
+        for plugin_name, plugin_config in plugins.items():
+            yaml_config = config.get(plugin_name, {})
+            if yaml_config.get("enable"):
+                plugin_config["enable"] = True
+                plugin_config["args"] = yaml_config.get("args", {})
+
+        # Create a dictionary of parameters for PipelineConfig
+        params = {
+            "model_name_or_path": model_name_or_path,
+            "device": device,
+            "plugins": plugins
+        }
+
         pipeline_config = PipelineConfig(**params)
         self.chatbot = build_chatbot(pipeline_config)
 
@@ -150,4 +141,7 @@ def __call__(self,
         config = get_config(config_file)
         if self.init(config):
             logging.basicConfig(filename=log_file, level=logging.INFO)
-            uvicorn.run(app, host=config.host, port=config.port)
+            try:
+                uvicorn.run(app, host=config.host, port=config.port)
+            except Exception as e:
+                print(f"Error starting uvicorn: {str(e)}")
```

intel_extension_for_transformers/neural_chat/server/restful/textchat_api.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -21,7 +21,6 @@
 from typing import Optional
 from fastapi import APIRouter
 from ...cli.log import logger
-from ...config import GenerationConfig
 from ...server.restful.openai_protocol import ChatCompletionRequest, ChatCompletionResponse
 
 
@@ -75,8 +74,7 @@ async def handle_chat_completion_request(self, request: ChatCompletionRequest) -
 
         try:
             logger.info(f"Predicting chat completion using prompt '{request.prompt}'")
-            config = GenerationConfig(max_new_tokens=64)
-            response = chatbot.predict(query=request.prompt, config=config)
+            response = chatbot.predict(query=request.prompt)
         except Exception as e:
             raise Exception(e)
         else:
```
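Dropping the hard-coded `GenerationConfig(max_new_tokens=64)` means completions now use the chatbot's own generation defaults. A caller who still wants a length cap can pass a config per request instead; a hedged sketch, with `chatbot` as built by the server:

```python
from intel_extension_for_transformers.neural_chat.config import GenerationConfig

# Optional per-request override, applied at the call site rather than being
# forced inside the REST handler (sketch; query text is illustrative).
response = chatbot.predict(query="What is NeuralChat?",
                           config=GenerationConfig(max_new_tokens=64))
```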

intel_extension_for_transformers/neural_chat/tests/requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -35,3 +35,4 @@ markdown
 rouge_score
 openpyxl
 numpy==1.23.5
+tiktoken==0.4.0
```
