This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit f49f2d9

[LLM Runtime] Add MX-Format (FP8_E5M2, FP8_E4M3, FP4_E2M1, NF4) (#872)
* add fp8 in llm frontend

Signed-off-by: Yu, Zhentao <zhentao.yu@intel.com>
1 parent c02dd7b commit f49f2d9

13 files changed, +156 -39 lines changed

.github/workflows/script/models/cpp_graph_inference.sh

Lines changed: 9 additions & 1 deletion
@@ -25,7 +25,7 @@ function main() {
         quant_script="./build/bin/quant_llama"
         infer_cmd="./build/bin/run_llama"
         input_model="/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf"
-        precision_list=("q4_j_b128" "q4_j_b32" "q4_0")
+        precision_list=("q4_j_b128" "q4_j_b32" "q4_0" "q8e4m3_j_f32_g128_fp8" "q8e5m2_j_f32_g128_fp8" "q4e2m1_j_f32_g128" "nf4_j_f32_g128")
     elif [[ "${model}" == "gpt-neox-20b" ]]; then
         convert_script="${working_dir}/scripts/convert_gptneox.py"
         quant_script="./build/bin/quant_gptneox"
@@ -120,6 +120,14 @@ function main() {
         # deprecated since bfloat16 scale not mature
         # elif [[ ${precision} == "q4_j_vnni_bf16_b32" ]]; then
         #     ${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --nthread $cores_per_instance --weight_dtype int4 --group_size 32 --scale_dtype bf16 --compute_dtype int8 --alg sym
+        elif [[ ${precision} == "q8e4m3_j_f32_g128_fp8" ]]; then
+            ${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --nthread $cores_per_instance --weight_dtype fp8 --group_size 128 --scale_dtype fp8 --compute_dtype fp32 --alg sym
+        elif [[ ${precision} == "q8e5m2_j_f32_g128_fp8" ]]; then
+            ${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --nthread $cores_per_instance --weight_dtype fp8_e5m2 --group_size 128 --scale_dtype fp8 --compute_dtype fp32 --alg sym
+        elif [[ ${precision} == "q4e2m1_j_f32_g128" ]]; then
+            ${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --nthread $cores_per_instance --weight_dtype fp4 --group_size 128 --scale_dtype fp32 --compute_dtype fp32 --alg sym
+        elif [[ ${precision} == "nf4_j_f32_g128" ]]; then
+            ${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --nthread $cores_per_instance --weight_dtype nf4 --group_size 128 --scale_dtype fp32 --compute_dtype fp32 --alg sym
         elif [[ ${precision} == "q4_j_vnni_b32" ]]; then
             ${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --nthread $cores_per_instance --weight_dtype int4 --group_size 32 --scale_dtype fp32 --compute_dtype int8 --alg sym
         elif [[ ${precision} == "q4_j_b32" ]]; then
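
For reference (a summary of the diff above, not code from the repository), the four new CI precision tags expand to the following quant_llama flag combinations:

precision_to_flags = {
    # tag:                   (weight_dtype, group_size, scale_dtype, compute_dtype, alg)
    "q8e4m3_j_f32_g128_fp8": ("fp8",      128, "fp8",  "fp32", "sym"),
    "q8e5m2_j_f32_g128_fp8": ("fp8_e5m2", 128, "fp8",  "fp32", "sym"),
    "q4e2m1_j_f32_g128":     ("fp4",      128, "fp32", "fp32", "sym"),
    "nf4_j_f32_g128":        ("nf4",      128, "fp32", "fp32", "sym"),
}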

intel_extension_for_transformers/llm/runtime/graph/README.md

Lines changed: 4 additions & 4 deletions
@@ -259,11 +259,11 @@ Argument description of run.py:
 | Argument | Description |
 | -------------- | --------------------------------------------------------------------- |
 | model | Directory containing model file or model id: String |
-| --weight_dtype | Data type of quantized weight: int4/int8 (default int4) |
+| --weight_dtype | Data type of quantized weight: int4/int8/fp8(=fp8_e4m3)/fp8_e5m2/fp4(=fp4e2m1)/nf4 (default int4) |
 | --alg | Quantization algorithm: sym/asym (default sym) |
-| --group_size | Group size: Int (default: 32) |
-| --scale_dtype | Data type of scales: fp32/bf16 (dafault fp32) |
-| --compute_dtype | Data type of Gemm computation: int8/bf16/fp32 (default: int8) |
+| --group_size | Group size: Int, 32/128/-1 (per channel) (default: 32) |
+| --scale_dtype | Data type of scales: fp32/bf16/fp8 (dafault fp32) |
+| --compute_dtype | Data type of Gemm computation: int8/bf16/fp16/fp32 (default: int8) |
 | --use_ggml | Enable ggml for quantization and inference |
 | -p / --prompt | Prompt to start generation with: String (default: empty) |
 | -n / --n_predict | Number of tokens to predict: Int (default: -1, -1 = infinity) |
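
As a usage illustration of the updated table (a sketch, not part of the commit), the new dtypes can be exercised through run.py roughly as follows. It assumes the command is launched from the llm/runtime/graph directory, reuses the model id from the Python example later in this diff, and follows the fp8-scale / fp32-compute pairings enforced by quantize.py below:

import subprocess

# Quantize and run with fp8_e5m2 weights; swap in fp8/fp8_e4m3/fp4/fp4_e2m1/nf4 as needed.
subprocess.run([
    "python", "scripts/run.py", "Intel/neural-chat-7b-v1-1",
    "--weight_dtype", "fp8_e5m2",
    "--scale_dtype", "fp8",        # fp8 weights currently pair with fp8 scales
    "--compute_dtype", "fp32",     # float weight dtypes do not use int8 compute
    "--group_size", "128",
    "-p", "Once upon a time, a little girl",
    "-n", "32",
], check=True)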

intel_extension_for_transformers/llm/runtime/graph/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -82,10 +82,6 @@ def init(self, model_name, use_quant=True, use_gptq=False, **quant_kwargs):
         self.__import_package(self.model_type)
 
         # check cache and quantization
-        if use_quant:
-            if quant_kwargs['weight_dtype'] == "int8" and quant_kwargs['compute_dtype'] == "bf16":
-                raise ValueError("Error: This combination (weight_dtype=int8, compute_dtype=bf16)"
-                                 " is not currently supported. Please use other combinations.")
         output_path = "runtime_outs"
         os.makedirs(output_path, exist_ok=True)
         fp32_bin = "{}/ne_{}_f32.bin".format(output_path, self.model_type)

intel_extension_for_transformers/llm/runtime/graph/application/common.cpp

Lines changed: 6 additions & 0 deletions
@@ -673,6 +673,12 @@ bool quant_params_parse(int argc, char** argv, quant_params& params) { // NOLIN
       params.nthread = std::stoi(argv[++i]);
     } else if (arg == "--weight_dtype") {
       params.weight_dtype = argv[++i];
+      if (params.weight_dtype == "fp8") {
+        params.weight_dtype = "fp8_e4m3";
+      }
+      if (params.weight_dtype == "fp4") {
+        params.weight_dtype = "fp4_e2m1";
+      }
     } else if (arg == "--alg") {
       params.alg = argv[++i];
     } else if (arg == "--group_size") {

intel_extension_for_transformers/llm/runtime/graph/application/common.h

Lines changed: 5 additions & 0 deletions
@@ -138,10 +138,15 @@ struct quant_params {
   std::string config = "";
   int nthread = 1;
 
+  // [int4, int8, fp8_e5m2, fp8_e4m3, fp4_e2m1, nf4]
   std::string weight_dtype = "int4";
+  // [sym, asym]
   std::string alg = "sym";
+  // [-1, 32, 128]
   int32_t group_size = 32;
+  // [fp32, bf16, fp8]
   std::string scale_dtype = "fp32";
+  // [fp32, fp16, bf16, int8]
   std::string compute_dtype = "int8";
   std::string model_name = "unknown";
   bool use_ggml = false;

intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp

Lines changed: 6 additions & 3 deletions
@@ -880,15 +880,18 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
   if (params.bits == quant_bits::q8) {
     quant_type = JBLAS_DTYPE::S8;
   }
-  if (params.bits == quant_bits::fp4) {
+  if (params.bits == quant_bits::fp4_e2m1) {
     quant_type = JBLAS_DTYPE::F4_E2M1;
   }
   if (params.bits == quant_bits::nf4) {
     quant_type = JBLAS_DTYPE::F4_NF4;
   }
-  if (params.bits == quant_bits::fp8) {
+  if (params.bits == quant_bits::fp8_e4m3) {
     quant_type = JBLAS_DTYPE::F8_E4M3;
   }
+  if (params.bits == quant_bits::fp8_e5m2) {
+    quant_type = JBLAS_DTYPE::F8_E5M2;
+  }
   auto dtype_type = static_cast<JBLAS_DTYPE>(
       jblas::utils::jblas_dtype_get_mask_val(quant_type, JBLAS_DTYPE::TypeMask, JBLAS_DTYPE::TypeShift));
   if (dtype_type == JBLAS_DTYPE::TypeFloat) {
@@ -906,7 +909,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
   if (params.scale_dtype == quant_sdtype::fp16) {
     printf("Current not support float16 scale, reset to bf16\n");
   }
-  if (quant_type == JBLAS_DTYPE::F8_E4M3) {
+  if (quant_type == JBLAS_DTYPE::F8_E4M3 || quant_type == JBLAS_DTYPE::F8_E5M2) {
     if (params.scale_dtype != quant_sdtype::fp8) {
       printf("Warning: fp8 weight only supports fp8 scale now\n");
     }
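
For orientation, the four float weight formats mapped above have the following rough numeric properties (general format facts, not values read from the jblas headers), summarized as a Python sketch:

jblas_float_formats = {
    "F8_E4M3": {"bits": 8, "exp_bits": 4, "mant_bits": 3, "max_normal": 448.0},
    "F8_E5M2": {"bits": 8, "exp_bits": 5, "mant_bits": 2, "max_normal": 57344.0},
    "F4_E2M1": {"bits": 4, "exp_bits": 2, "mant_bits": 1, "max_normal": 6.0},
    "F4_NF4":  {"bits": 4, "levels": 16, "value_range": (-1.0, 1.0)},  # normal-float lookup table
}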

intel_extension_for_transformers/llm/runtime/graph/models/model_utils/quant_config.h

Lines changed: 8 additions & 5 deletions
@@ -18,22 +18,25 @@
 #include "core/data_types.h"
 #include "jblas/jit_blas.h"
 
-enum class quant_bits : int { q4 = 0, q8, fp4, nf4, fp8, count };
+enum class quant_bits : int { q4 = 0, q8, fp4_e2m1, nf4, fp8_e4m3, fp8_e5m2, count };
 static inline quant_bits parse_bits(const std::string& bits) {
   if (bits == "int4") {
     return quant_bits::q4;
   }
   if (bits == "int8") {
     return quant_bits::q8;
   }
-  if (bits == "fp4") {
-    return quant_bits::fp4;
+  if (bits == "fp4_e2m1" || bits == "fp4") {
+    return quant_bits::fp4_e2m1;
   }
   if (bits == "nf4") {
     return quant_bits::nf4;
   }
-  if (bits == "fp8") {
-    return quant_bits::fp8;
+  if (bits == "fp8_e4m3" || bits == "fp8") {
+    return quant_bits::fp8_e4m3;
+  }
+  if (bits == "fp8_e5m2") {
+    return quant_bits::fp8_e5m2;
   }
   return quant_bits::count;
 }
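
To make the new enum names concrete, here is a minimal generic mini-float decoder (an illustrative sketch, not code from this repository); it ignores NaN/Inf encodings, which differ between E4M3 and E5M2:

def decode_minifloat(code: int, exp_bits: int, mant_bits: int, bias: int) -> float:
    """Decode an unsigned sign/exponent/mantissa bit pattern into a Python float."""
    sign = -1.0 if (code >> (exp_bits + mant_bits)) & 1 else 1.0
    exp = (code >> mant_bits) & ((1 << exp_bits) - 1)
    mant = code & ((1 << mant_bits) - 1)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (mant / (1 << mant_bits)) * 2.0 ** (1 - bias)
    return sign * (1.0 + mant / (1 << mant_bits)) * 2.0 ** (exp - bias)

print(decode_minifloat(0x38, 4, 3, 7))   # fp8_e4m3 (bias 7)  -> 1.0
print(decode_minifloat(0x38, 5, 2, 15))  # fp8_e5m2 (bias 15) -> 0.5
print(decode_minifloat(0x7, 2, 1, 1))    # fp4_e2m1 (bias 1)  -> 6.0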

intel_extension_for_transformers/llm/runtime/graph/scripts/python_api_example.py

Lines changed: 7 additions & 1 deletion
@@ -19,17 +19,23 @@
 from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
 
 model_name = "Intel/neural-chat-7b-v1-1" # or local path to model
+# int4 weight_only quantization
 woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
+# nf4 weight_only quantization
+# woq_config = WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="nf4")
+# fp8 weight_only quantization
+# woq_config = WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="fp8")
 prompt = "Once upon a time, a little girl"
 
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 inputs = tokenizer(prompt, return_tensors="pt").input_ids
 streamer = TextStreamer(tokenizer)
 
-model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
 # top_k_top_p sample or greedy_search
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
 outputs = model.generate(inputs, streamer=streamer, max_new_tokens=300)
 # beam search
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
 outputs = model.generate(inputs, num_beams=4, max_new_tokens=128, min_new_tokens=30, early_stopping=True)
 ans = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 print(ans)
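
A small follow-up sketch (not part of the commit) that gathers the configurations demonstrated above so one can be picked by name; it only uses the compute_dtype/weight_dtype pairings shown in this example:

woq_configs = {
    "int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4"),
    "nf4": WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="nf4"),
    "fp8": WeightOnlyQuantConfig(compute_dtype="fp32", weight_dtype="fp8"),  # treated as fp8_e4m3 downstream
}
woq_config = woq_configs["nf4"]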

intel_extension_for_transformers/llm/runtime/graph/scripts/quantize.py

Lines changed: 33 additions & 5 deletions
@@ -52,26 +52,36 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--nthread", type=int, help="Number of threads to use: Int (default: 1)", default=1)
     parser.add_argument(
         "--weight_dtype",
-        choices=["int4", "int8"],
+        choices=["int4", "int8", "fp8", "fp8_e5m2", "fp8_e4m3",
+                 "fp4", "fp4_e2m1", "nf4"],
         help="Data type of quantized weight: int4/int8 (default: int4)",
         default="int4",
     )
     parser.add_argument(
         "--alg",
         type=str,
+        choices=["sym", "asym"],
         help="Quantization algorithm to use: sym/asym (default: sym)",
         default="sym",
     )
-    parser.add_argument("--group_size", type=int, help="Group size: Int (default: 32)", default=32)
+    parser.add_argument(
+        "--group_size",
+        type=int,
+        choices=[-1, 32, 128],
+        help="Group size: Int (default: 32)",
+        default=32,
+    )
     parser.add_argument(
         "--scale_dtype",
         type=str,
+        choices=["fp32", "bf16", "fp8"],
         help="Data type of scales: bf16/fp32 (default: fp32)",
         default="fp32",
     )
     parser.add_argument(
         "--compute_dtype",
         type=str,
+        choices=["fp32", "fp16", "bf16", "int8"],
         help="Data type of Gemm computation: int8/bf16/fp32 (default: int8)",
         default="int8",
     )
@@ -97,10 +107,28 @@ def main(args_in: Optional[List[str]] = None) -> None:
     cmd.extend(["--out_file", args.out_file])
     cmd.extend(["--nthread", str(args.nthread)])
     cmd.extend(["--weight_dtype", str(args.weight_dtype)])
-    cmd.extend(["--alg", args.alg])
+    if (str(args.weight_dtype))[:3] in ["fp8", "fp4", "nf4"] and str(args.alg) in ["asym"]:
+        print("WARNING: asym alg is not be supported in float quant types. Fall back to sym.");
+        cmd.extend(["--alg", "sym"])
+    else:
+        cmd.extend(["--alg", args.alg])
     cmd.extend(["--group_size", str(args.group_size)])
-    cmd.extend(["--scale_dtype", args.scale_dtype])
-    cmd.extend(["--compute_dtype", args.compute_dtype])
+    if (str(args.weight_dtype))[:3] not in ["fp8"]:
+        sdtype = str(args.scale_dtype)
+        if str(args.scale_dtype) in ["fp8"]:
+            print("WARNING: fp8 scale is only be supported in fp8 weight type. Fall back to fp32.");
+            sdtype = "fp32"
+        cmd.extend(["--scale_dtype", sdtype])
+    else:
+        if str(args.scale_dtype) != "fp8":
+            print("WARNING: fp8 weight type only supports fp8 scale now.Fall back to fp8.")
+        cmd.extend(["--scale_dtype", "fp8"])
+    if (str(args.weight_dtype))[:3] in ["fp8", "fp4", "nf4"] and str(args.compute_dtype) in ["int8"]:
+        print("WARNING: int8 compute dtype is not be supported in float quant types! "\
+              "Fall back to fp32.")
+        cmd.extend(["--compute_dtype", "fp32"])
+    else:
+        cmd.extend(["--compute_dtype", args.compute_dtype])
     if args.use_ggml:
         cmd.extend(["--use_ggml"])
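
The dtype-compatibility rules added above can be restated as a pure function (a sketch for readability, not code from the repository), which makes the weight_dtype/alg/scale_dtype/compute_dtype interactions easier to scan:

def resolve_quant_args(weight_dtype: str, alg: str, scale_dtype: str, compute_dtype: str):
    """Mirror the quantize.py fallbacks: float weight dtypes are sym-only,
    fp8 weights pair with fp8 scales, and float weights do not use int8 compute."""
    is_float_weight = weight_dtype[:3] in ("fp8", "fp4", "nf4")
    if is_float_weight and alg == "asym":
        alg = "sym"
    if weight_dtype[:3] == "fp8":
        scale_dtype = "fp8"
    elif scale_dtype == "fp8":
        scale_dtype = "fp32"
    if is_float_weight and compute_dtype == "int8":
        compute_dtype = "fp32"
    return alg, scale_dtype, compute_dtype

# Example: resolve_quant_args("fp8_e5m2", "asym", "fp32", "int8") -> ("sym", "fp8", "fp32")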

intel_extension_for_transformers/llm/runtime/graph/scripts/run.py

Lines changed: 12 additions & 2 deletions
@@ -42,26 +42,36 @@ def main(args_in: Optional[List[str]] = None) -> None:
     # quantization related arguments.
     parser.add_argument(
         "--weight_dtype",
-        choices=["int4", "int8"],
+        choices=["int4", "int8", "fp8", "fp8_e5m2", "fp8_e4m3",
+                 "fp4", "fp4_e2m1", "nf4"],
         help="Data type of quantized weight: int4/int8 (default int4)",
         default="int4",
     )
     parser.add_argument(
         "--alg",
         type=str,
+        choices=["sym", "asym"],
         help="Quantization algorithm: sym/asym (default sym)",
         default="sym",
     )
-    parser.add_argument("--group_size", type=int, help="Group size: Int (default: 32)", default=32)
+    parser.add_argument(
+        "--group_size",
+        type=int,
+        choices=[-1, 32, 128],
+        help="Group size: Int (default: 32)",
+        default=32,
+    )
     parser.add_argument(
         "--scale_dtype",
         type=str,
+        choices=["fp32", "bf16", "fp8"],
         help="Data type of scales: fp32/bf16 (dafault fp32)",
         default="fp32",
     )
     parser.add_argument(
         "--compute_dtype",
         type=str,
+        choices=["fp32", "fp16", "bf16", "int8"],
         help="Data type of Gemm computation: int8/bf16/fp32 (default: int8)",
         default="int8",
     )
