@@ -301,6 +301,8 @@ def recover_qparms(self):
qzeros = torch.ops.bestlaop.acquire_woq_packw_info(self.weight, 10)
if bits == 4:
qzeros = qzeros // 16 + 8
else:
qzeros = (qzeros.to(torch.int32) + 128).to(torch.uint8)
else:
qzeros = None
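For context, a minimal sketch of the round-trip this new 8-bit branch performs (illustrative values only, not the backend's actual output): the qbits backend hands back zero points as signed int8 carrying the -128 offset that unpack_weight applies below, so recovery widens to int32, adds 128, and casts back to uint8.

```python
import torch

# Hypothetical int8 zero points as the backend would hold them (offset by -128).
stored = torch.tensor([-128, -1, 0, 127], dtype=torch.int8)

# Widen before adding 128 so the sum cannot overflow int8, then cast to uint8.
recovered = (stored.to(torch.int32) + 128).to(torch.uint8)
print(recovered.tolist())  # [0, 127, 128, 255]
```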

@@ -54,24 +54,45 @@


def unpack_weight(qweight, scales, qzeros, q_config):
sym = q_config.sym
bits = q_config.bits
wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)

zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)
).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
if bits == 8:
zeros = zeros.to(torch.int8)
zeros = zeros.to(torch.int8 if sym else torch.uint8)
# INC stores the zero points minus one, so add it back
zeros = zeros + 1
zeros = zeros.reshape(scales.shape)
try:
zeros = zeros.reshape(scales.shape)
except:
# zeros and scales have different item counts.
# remove the extra 1s introduced by the `zeros + 1` above
zeros = zeros[zeros != 1]
zeros = zeros.reshape(scales.shape)

# INC asym returns torch.uint8 but the backend expects int8,
# so shift it to int8 with an offset of 128
if not sym and bits == 8:
zeros = (zeros.to(torch.int32) - 128).to(torch.int8)

weight = torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)
).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(weight, (2**bits) - 1, out=weight)

if bits == 8:
weight = weight.to(torch.int8)
# INC adds a shift bias for sym, so subtract it here
if sym:
shift_bias = 2 ** (bits - 1)
weight -= shift_bias
weight = weight.to(torch.int8 if sym else torch.uint8)
# INC asym returns torch.uint8 but the backend expects int8,
# so shift it to int8 with an offset of 128
if not sym:
weight = (weight.to(torch.int32) - 128).to(torch.int8)
return weight, scales, zeros
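For readers less familiar with this packing scheme, here is a standalone sketch (made-up values, not the repository's tensors) of the two ideas in unpack_weight: low-bit values are recovered from an int32 by shifting and masking, and the asymmetric path can safely subtract 128 from both the uint8 weight and the uint8 zero point because scale * (q - zp) is unchanged by a common offset.

```python
import torch

bits = 4
values = [1, 7, 0, 15, 8, 3, 12, 5]            # eight 4-bit values
packed = 0
for i, v in enumerate(values):                 # pack low nibble first
    packed |= v << (i * bits)
packed = torch.tensor([packed], dtype=torch.int32)

# Shift-and-mask unpacking, mirroring the wf / bitwise_right_shift pattern above.
wf = torch.arange(0, 32, bits, dtype=torch.int32)   # bit offsets 0, 4, ..., 28
unpacked = torch.bitwise_right_shift(packed.unsqueeze(-1), wf)
unpacked = torch.bitwise_and(unpacked, (2 ** bits) - 1)
assert unpacked.flatten().tolist() == values

# Why the asymmetric int8 conversion is lossless: subtracting 128 from both the
# uint8 weight and the uint8 zero point leaves q - zp, and thus the dequantized
# value, unchanged.
scale = torch.tensor(0.05)
q_uint8 = torch.tensor([0, 37, 128, 255], dtype=torch.uint8)
zp_uint8 = torch.tensor(131, dtype=torch.uint8)
ref = scale * (q_uint8.to(torch.int32) - zp_uint8.to(torch.int32))

q_int8 = (q_uint8.to(torch.int32) - 128).to(torch.int8)
zp_int8 = (zp_uint8.to(torch.int32) - 128).to(torch.int8)
out = scale * (q_int8.to(torch.int32) - zp_int8.to(torch.int32))
assert torch.equal(ref, out)                   # (q - 128) - (zp - 128) == q - zp
```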


@@ -238,7 +259,7 @@ def _replace_linear(
model._modules[name].requires_grad_(False)
if device == "cpu" or device == torch.device("cpu") or device == "auto":
if quantization_config.weight_dtype in \
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int8", "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
model._modules[name].set_fp_weights_bias(
module.weight.data,
None if module.bias is None else module.bias.data,
@@ -506,7 +527,7 @@ def default_calib_func(model):

q_model = replace_linear(model, None, None, config, device=device)
else:
if config.weight_dtype not in ["nf4", "fp4", "int8", "int4_fullrange"]:
if config.weight_dtype not in ["nf4", "fp4", "int4_fullrange"]:
inc_model = inc_model.export_compressed_model(use_optimum_format=True)
inc_model.eval()
q_model = replace_linear(inc_model, None, None, config, device=device)
@@ -195,7 +195,7 @@ def save_low_bit(
return

if self.quantization_config.weight_dtype not in \
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int8", "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange"]:
convert_model_to_public(self)
os.makedirs(save_directory, exist_ok=True)
# use transformers original `save_pretrained` function
@@ -1171,7 +1171,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
else:
model = model_class(config, *model_args, **kwargs)
if config.quantization_config["weight_dtype"] not in \
["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int8", "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "fp4", "nf4", "int4_fullrange"]:
model = build_woq_model(model, quantization_config)
else:
model = replace_linear(
@@ -1222,7 +1222,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if config.quantization_config["weight_dtype"] not in \
["fp8_e5m2", "fp8_e4m3", "int8", "nf4", "fp4" "int4_fullrange"]:
["fp8_e5m2", "fp8_e4m3", "nf4", "fp4" "int4_fullrange"]:
model = replace_linear(
model,
quantization_config=quantization_config,
2 changes: 1 addition & 1 deletion tests/CI/test_quantization.py
@@ -433,7 +433,7 @@ def test_quantization_for_llm(self):
)
bit8_model.eval()
output = bit8_model(dummy_input)
self.assertTrue(isclose(float(output[0][0][0][0]), 0.1675747185945511, rel_tol=1e-04))
self.assertTrue(isclose(float(output[0][0][0][0]), 0.16759155690670013, rel_tol=1e-04))

# GPTQ
woq_config = GPTQConfig(bits=4,
10 changes: 8 additions & 2 deletions tests/CI/test_weight_only.py
@@ -29,8 +29,14 @@
Trainer
)
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from intel_extension_for_transformers.transformers.llm.quantization.nn.modules import QuantizedLinearQBits, QuantizedLoraLinearQBits
from intel_extension_for_transformers.transformers.llm.quantization.utils import convert_to_quantized_model, replace_linear
from intel_extension_for_transformers.transformers.llm.quantization.nn.modules import (
QuantizedLinearQBits,
QuantizedLoraLinearQBits
)
from intel_extension_for_transformers.transformers.llm.quantization.utils import (
convert_to_quantized_model,
replace_linear
)
from intel_extension_for_transformers.transformers.llm.utils.generation import _beam_search, _greedy_search
from intel_extension_for_transformers.transformers import RtnConfig
