This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit b676256

[NeuralChat] Remove unnecessary model load during optimizing model (#722)
* Remove unnecessary model load during optimizing model

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
1 parent 40472f0 commit b676256

File tree

2 files changed: +29 -24 lines


intel_extension_for_transformers/llm/quantization/optimization.py

Lines changed: 20 additions & 15 deletions
@@ -25,7 +25,11 @@ def __init__(
         self.optimization_config = optimization_config

     def optimize(self, model, use_llm_runtime=False):
-        optimized_model = model
+        if isinstance(model, str):
+            model_name = model
+        else:
+            model_name = model.config._name_or_path
+        optimized_model = model
         from intel_extension_for_transformers.transformers import (
             MixedPrecisionConfig,
             WeightOnlyQuantConfig,
@@ -35,39 +39,40 @@ def optimize(self, model, use_llm_runtime=False):
             f"Expect optimization_config be an object of MixedPrecisionConfig, WeightOnlyQuantConfig" + \
             " or BitsAndBytesConfig,got {type(self.optimization_config)}."
         config = self.optimization_config
-        if re.search("flan-t5", model.config._name_or_path, re.IGNORECASE):
+        if re.search("flan-t5", model_name, re.IGNORECASE):
             from intel_extension_for_transformers.transformers import AutoModelForSeq2SeqLM
             optimized_model = AutoModelForSeq2SeqLM.from_pretrained(
-                model.config._name_or_path,
+                model_name,
                 quantization_config=config,
                 use_llm_runtime=use_llm_runtime,
                 trust_remote_code=True)
         elif (
-            re.search("gpt", model.config._name_or_path, re.IGNORECASE)
-            or re.search("mpt", model.config._name_or_path, re.IGNORECASE)
-            or re.search("bloom", model.config._name_or_path, re.IGNORECASE)
-            or re.search("llama", model.config._name_or_path, re.IGNORECASE)
-            or re.search("opt", model.config._name_or_path, re.IGNORECASE)
-            or re.search("neural-chat-7b-v1", model.config._name_or_path, re.IGNORECASE)
-            or re.search("neural-chat-7b-v2", model.config._name_or_path, re.IGNORECASE)
+            re.search("gpt", model_name, re.IGNORECASE)
+            or re.search("mpt", model_name, re.IGNORECASE)
+            or re.search("bloom", model_name, re.IGNORECASE)
+            or re.search("llama", model_name, re.IGNORECASE)
+            or re.search("opt", model_name, re.IGNORECASE)
+            or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)
+            or re.search("neural-chat-7b-v2", model_name, re.IGNORECASE)
+            or re.search("neural-chat-7b-v3", model_name, re.IGNORECASE)
         ):
             from intel_extension_for_transformers.transformers import AutoModelForCausalLM
             optimized_model = AutoModelForCausalLM.from_pretrained(
-                model.config._name_or_path,
+                model_name,
                 quantization_config=config,
                 use_llm_runtime=use_llm_runtime,
                 trust_remote_code=True)
-        elif re.search("starcoder", model.config._name_or_path, re.IGNORECASE):
+        elif re.search("starcoder", model_name, re.IGNORECASE):
             from intel_extension_for_transformers.transformers import GPTBigCodeForCausalLM
             optimized_model = GPTBigCodeForCausalLM.from_pretrained(
-                model.config._name_or_path,
+                model_name,
                 quantization_config=config,
                 use_llm_runtime=use_llm_runtime,
                 trust_remote_code=True)
-        elif re.search("chatglm", model.config._name_or_path, re.IGNORECASE):
+        elif re.search("chatglm", model_name, re.IGNORECASE):
             from intel_extension_for_transformers.transformers import AutoModel
             optimized_model = AutoModel.from_pretrained(
-                model.config._name_or_path,
+                model_name,
                 quantization_config=config,
                 use_llm_runtime=use_llm_runtime,
                 trust_remote_code=True)
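
Net effect of this file's change: optimize() resolves a model_name once, accepting either a checkpoint-name string or an already-loaded model, and every loader branch then quantizes from that name instead of re-reading model.config._name_or_path. A minimal, self-contained sketch of that resolve-then-dispatch pattern (pure Python; resolve_model_name and pick_loader are illustrative helpers for this page, not the project's API):

import re

def resolve_model_name(model):
    # Accept either a checkpoint name string or a loaded model that
    # exposes config._name_or_path, as transformers models do.
    if isinstance(model, str):
        return model
    return model.config._name_or_path

def pick_loader(model_name):
    # Mirror the dispatch order in the diff above; the fallback branch
    # stands in for the real code keeping the input model unchanged.
    if re.search("flan-t5", model_name, re.IGNORECASE):
        return "AutoModelForSeq2SeqLM"
    if re.search(r"gpt|mpt|bloom|llama|opt|neural-chat-7b-v[123]",
                 model_name, re.IGNORECASE):
        return "AutoModelForCausalLM"
    if re.search("starcoder", model_name, re.IGNORECASE):
        return "GPTBigCodeForCausalLM"
    if re.search("chatglm", model_name, re.IGNORECASE):
        return "AutoModel"
    return "unchanged"

print(pick_loader(resolve_model_name("Intel/neural-chat-7b-v3")))
# -> AutoModelForCausalLM

Because each branch only needs the name, a caller holding just a checkpoint string can reach the quantized from_pretrained call without ever loading full-precision weights first.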

intel_extension_for_transformers/neural_chat/models/model_utils.py

Lines changed: 9 additions & 9 deletions
@@ -332,6 +332,15 @@ def load_model(
     config = AutoConfig.from_pretrained(model_name, use_auth_token=hf_access_token, trust_remote_code=True \
         if re.search("chatglm", model_name, re.IGNORECASE) else False)
     load_to_meta = model_on_meta(config)
+
+    if isinstance(optimization_config, WeightOnlyQuantConfig):
+        from intel_extension_for_transformers.neural_chat.chatbot import optimize_model
+        model = optimize_model(model_name, optimization_config, use_llm_runtime)
+        MODELS[model_name]["model"] = model
+        MODELS[model_name]["tokenizer"] = tokenizer
+        print("Optimized Model loaded.")
+        return
+
     if peft_path and device == "hpu" and use_deepspeed and load_to_meta:
         print("PEFT could not work in deepspeed sharded checkpt loading mode, set load_to_meta to False")
         load_to_meta = False
@@ -426,15 +435,6 @@ def load_model(
     if model.generation_config.eos_token_id is None:
         model.generation_config.eos_token_id = tokenizer.eos_token_id

-    if isinstance(optimization_config, WeightOnlyQuantConfig):
-        from intel_extension_for_transformers.neural_chat.chatbot import optimize_model
-        model = optimize_model(model, optimization_config, use_llm_runtime)
-
-        MODELS[model_name]["model"] = model
-        MODELS[model_name]["tokenizer"] = tokenizer
-        print("Optimized Model loaded.")
-        return
-
     if device == "hpu":
         if peft_path:
             from peft import PeftModel
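
In model_utils.py the weight-only-quantization branch simply moves ahead of the full-precision load and now receives model_name rather than a loaded model, so load_model returns before any FP32/FP16 weights are materialized. A toy sketch of that reordering (stand-in strings replace the real tokenizer, model, and optimize_model call; none of this is the project's actual API):

MODELS = {}

def load_model_sketch(model_name, weight_only_quant_config=None):
    tokenizer = f"tokenizer({model_name})"      # stands in for AutoTokenizer.from_pretrained
    MODELS.setdefault(model_name, {})
    if weight_only_quant_config is not None:    # stands in for isinstance(..., WeightOnlyQuantConfig)
        # New ordering: quantize straight from the checkpoint name ...
        model = f"woq({model_name})"            # stands in for optimize_model(model_name, config, ...)
        MODELS[model_name]["model"] = model
        MODELS[model_name]["tokenizer"] = tokenizer
        return                                  # ... and skip the full-precision load entirely
    # The load below is what the old ordering always paid for, even when
    # the model was about to be replaced by its quantized counterpart.
    MODELS[model_name]["model"] = f"fp32({model_name})"
    MODELS[model_name]["tokenizer"] = tokenizer

load_model_sketch("Intel/neural-chat-7b-v3", weight_only_quant_config={"bits": 4})
print(MODELS["Intel/neural-chat-7b-v3"]["model"])
# -> woq(Intel/neural-chat-7b-v3)

The early return is the memory win: in the old ordering the full-precision model was loaded near the end of load_model and then discarded once the quantized replacement was built.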
