Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 3068496

Browse files
authored
[NeuralChat] Fix response issue of model.predict (#1221)
* fix response issue of model.predict Signed-off-by: LetongHan <letong.han@intel.com>
1 parent 87d108b commit 3068496

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

intel_extension_for_transformers/neural_chat/models/model_utils.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1445,18 +1445,21 @@ def predict(**params):
14451445
output = tokenizer.decode(generation_output[0], skip_special_tokens=True)
14461446
else:
14471447
output = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
1448+
1449+
identifier_index = -1
14481450
if "### Response:" in output:
1449-
return output.split("### Response:")[-1].strip()
1451+
return output.split("### Response:")[identifier_index].strip()
14501452
if "@@ Response" in output:
1451-
return output.split("@@ Response")[-1].strip()
1453+
return output.split("@@ Response")[identifier_index].strip()
14521454
if "### Assistant" in output:
1453-
return output.split("### Assistant:")[-1].strip()
1455+
return output.split("### Assistant:")[identifier_index].strip()
14541456
if "\nassistant\n" in output:
1455-
return output.split("\nassistant\n")[-1].strip()
1457+
return output.split("\nassistant\n")[identifier_index].strip()
14561458
if "[/INST]" in output:
1457-
return output.split("[/INST]")[-1].strip()
1459+
return output.split("[/INST]")[identifier_index].strip()
14581460
if "答:" in output:
1459-
return output.split("答:")[-1].strip()
1461+
return output.split("答:")[identifier_index].strip()
14601462
if "Answer:" in output:
1461-
return output.split("Answer:")[-1].strip()
1463+
return output.split("Answer:")[identifier_index].strip()
1464+
14621465
return output

intel_extension_for_transformers/neural_chat/tests/ci/models/test_model_utils.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import shutil
2121
from unittest import mock
22-
from intel_extension_for_transformers.neural_chat.models.model_utils import load_model, MODELS
22+
from intel_extension_for_transformers.neural_chat.models.model_utils import load_model, MODELS, predict
2323
from intel_extension_for_transformers.transformers import MixedPrecisionConfig, BitsAndBytesConfig, WeightOnlyQuantConfig
2424
from intel_extension_for_transformers.neural_chat.utils.common import get_device_type
2525
from intel_extension_for_transformers.neural_chat.utils.error_utils import clear_latest_error, get_latest_error
@@ -139,5 +139,19 @@ def test_model_optimization_weightonly(self):
139139
self.assertTrue("facebook/opt-125m" in MODELS)
140140
self.assertTrue(MODELS["facebook/opt-125m"]["model"] is not None)
141141

142+
@unittest.skipIf(get_device_type() != 'cpu', "Only run this test on CPU")
143+
def test_model_predict(self):
144+
load_model(model_name="facebook/opt-125m", tokenizer_name="facebook/opt-125m", device="cpu")
145+
self.assertTrue("facebook/opt-125m" in MODELS)
146+
self.assertTrue(MODELS["facebook/opt-125m"]["model"] is not None)
147+
148+
params = {
149+
"model_name": "facebook/opt-125m",
150+
"prompt": "hi"
151+
}
152+
output = predict(**params)
153+
self.assertIn("hi", output)
154+
self.assertNotIn("[/INST]", output)
155+
142156
if __name__ == '__main__':
143157
unittest.main()

0 commit comments

Comments (0)