8 changes: 4 additions & 4 deletions docs/quantized_ops.md
@@ -99,12 +99,12 @@ orig_model.linear = q_linear

| Weight Quantization Type | Activation Quantization Type | Dtype | Supported |
|---|---|---|---|
| per-channel | N/A | W8A16 | Yes |
| per-channel | N/A | W4A16 | Yes |
| per-channel (sym/asym) | N/A | W8A16 | Yes |
| per-channel (sym/asym) | N/A | W4A16 | Yes |
| per-channel | per-token | W8A8 | No |
| per-channel | per-token | W4A8 | No |
| blockwise | N/A | W8A16 | Yes |
| blockwise | N/A | W4A16 | Yes |
| blockwise (sym/asym) | N/A | W8A16 | Yes |
| blockwise (sym/asym) | N/A | W4A16 | Yes |
| blockwise | per-token | W8A8 | No |
| blockwise | per-token | W4A8 | No |

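The asymmetric (`asym`) rows above can be exercised end to end with `XlaQuantizedLinear`. Below is a minimal sketch, assuming a hand-rolled round-to-nearest int8 quantizer and a CPU run of the reference (`CompositeExplicitAutograd`) kernel; only `XlaQuantizedLinear`, `load_quantized_weight`, and `is_symmetric` come from this change, everything else is illustrative:

```python
import torch
from torch_xla.experimental.xla_quantized_matmul import XlaQuantizedLinear

in_features, out_features = 6, 8
linear = torch.nn.Linear(in_features, out_features, bias=False)
w = linear.weight.data  # [out_features, in_features]

# Asymmetric per-channel int8 round-to-nearest quantization over [-128, 127].
min_val, max_val = torch.aminmax(w, dim=1)
scale = (max_val - min_val).clamp(min=1e-5) / 255.0
zp_int = torch.round(-128 - min_val / scale)
w_int = torch.clamp(torch.round(w / scale[:, None]) + zp_int[:, None], -128,
                    127).to(torch.int8)
# The op consumes the zero point already multiplied by the scale.
zero_point = zp_int * scale

q_linear = XlaQuantizedLinear(in_features, out_features, is_symmetric=False)
q_linear.load_quantized_weight(w_int, scale, zero_point=zero_point)

x = torch.randn(3, in_features)
# The quantized output should closely track the float output (no bias handling here).
print(torch.nn.functional.cosine_similarity(linear(x), q_linear(x), dim=-1))
```

The layer can then replace the float `nn.Linear` exactly as in the symmetric example earlier in this doc (`orig_model.linear = q_linear`).
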
117 changes: 80 additions & 37 deletions test/quantized_ops/test_quantized_matmul.py
@@ -31,61 +31,66 @@ def weight_quantization_rtn(self,
'''
assert isinstance(self.linear, torch.nn.Linear)
w_fp = linear.weight.data
is_symmetric = quant_method == torch.per_channel_symmetric or quant_method == torch.per_tensor_symmetric
if block_size == -1:
min_val, max_val = torch.aminmax(
w_fp, dim=1) # min_val, max_val [out_dim]
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scaler, zero_point = determine_qparams(
min_val,
max_val,
int_min,
int_max,
dtype=torch.int8,
eps=torch.Tensor([1e-5]),
has_customized_qrange=False,
qscheme=quant_method)
w_int = torch.ops.quantized_decomposed.quantize_per_channel(
w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
return w_int, scaler.to(w_fp.dtype), zero_point
else:
assert w_fp.shape[1] % block_size == 0
output_dim = w_fp.shape[0]
input_dim = w_fp.shape[1]
w_fp = w_fp.reshape(output_dim * input_dim // block_size, block_size)
min_val, max_val = torch.aminmax(
w_fp, dim=1) # min_val, max_val [out_dim]
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scaler, zero_point = determine_qparams(
min_val,
max_val,
int_min,
int_max,
dtype=torch.int8,
eps=torch.Tensor([1e-5]),
has_customized_qrange=False,
qscheme=quant_method)
w_int = torch.ops.quantized_decomposed.quantize_per_channel(
w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
int_min = -2**(n_bits - 1)
int_max = 2**(n_bits - 1) - 1
scaler, zero_point = determine_qparams(
min_val,
max_val,
int_min,
int_max,
dtype=torch.int8,
eps=torch.Tensor([1e-5]),
has_customized_qrange=False,
qscheme=quant_method)
w_int = torch.ops.quantized_decomposed.quantize_per_channel(
w_fp, scaler, zero_point, 0, int_min, int_max, torch.int8)
scaler = scaler.to(w_fp.dtype)
dq_w = torch.ops.quantized_decomposed.dequantize_per_channel(
w_int, scaler, zero_point, 0, int_min, int_max, dtype=torch.int8)
if not is_symmetric:
zero_point = zero_point.to(w_fp.dtype)
zero_point *= scaler
else:
zero_point = None
if block_size != -1:
w_int = w_int.reshape(output_dim, input_dim // block_size,
block_size).permute(1, 2, 0)
scaler = scaler.to(w_fp.dtype).reshape(output_dim,
input_dim // block_size).permute(
1, 0)
return w_int, scaler, zero_point

def replace_with_xla_quantized_matmul(self, n_bit=8, block_size=-1):
scaler = scaler.reshape(output_dim, input_dim // block_size).permute(1, 0)
if not is_symmetric:
zero_point = zero_point.reshape(output_dim,
input_dim // block_size).permute(1, 0)
return w_int, scaler, zero_point

def replace_with_xla_quantized_matmul(self,
n_bit=8,
block_size=-1,
is_symmetric=True):
assert isinstance(self.linear, torch.nn.Linear)
w_int, scaler, _ = self.weight_quantization_rtn(
self.linear, n_bits=n_bit, block_size=block_size)
w_int, scaler, zero_point = self.weight_quantization_rtn(
self.linear,
n_bits=n_bit,
block_size=block_size,
quant_method=torch.per_channel_symmetric
if is_symmetric else torch.per_channel_affine)
use_int4_weight = n_bit == 4
q_linear = XlaQuantizedLinear(
self.linear.in_features,
self.linear.out_features,
block_size=block_size,
int4_weight=use_int4_weight)
q_linear.load_quantized_weight(w_int, scaler)
int4_weight=use_int4_weight,
is_symmetric=is_symmetric)
q_linear.load_quantized_weight(w_int, scaler, zero_point=zero_point)
self.linear = q_linear

def forward(self, x):
@@ -225,6 +230,44 @@ def test_blockwise_linear_module(self):
self.assertGreater(
self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)

def test_asymmetric_per_channel(self):
for n_bit in [4, 8]:
with self.subTest(n_bit=n_bit):
m = M(6, 8)
x = torch.randn(3, 6)
out_fp = m(x)
m.replace_with_xla_quantized_matmul(
n_bit=n_bit, block_size=-1, is_symmetric=False)
out_quant = m(x)
self.assertGreater(self._calc_cosine_dist(out_fp, out_quant), 0.99)

# Dot with int4 weight is only supported on TPU
if not (n_bit == 4 and xr.device_type() != 'TPU'):
m = m.to(device)
x = x.to(device)
out_quant_xla = m(x)
self.assertGreater(
self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)

def test_asymmetric_blockwise(self):
for n_bit in [8]:
with self.subTest(n_bit=n_bit):
m = M(6, 8)
x = torch.randn(2, 6)
out_fp = m(x)
m.replace_with_xla_quantized_matmul(
n_bit=n_bit, block_size=2, is_symmetric=False)
out_quant = m(x)
self.assertGreater(self._calc_cosine_dist(out_fp, out_quant), 0.99)

# Dot with int4 weight is only supported on TPU
if not (n_bit == 4 and xr.device_type() != 'TPU'):
m = m.to(device)
x = x.to(device)
out_quant_xla = m(x)
self.assertGreater(
self._calc_cosine_dist(out_quant_xla.cpu(), out_quant), 0.999999)


if __name__ == '__main__':
unittest.main()
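
Note on the blockwise cases above: the test helper permutes the quantized weight and the per-block parameters into the layouts that the shape checks in `xla_quantized_matmul.py` expect, i.e. `[in_channel / block_size, block_size, out_channel]` for the weight and `[in_channel / block_size, out_channel]` for the scale and zero point. A minimal sketch of that transform, with made-up tensors:

```python
import torch

out_features, in_features, block_size = 8, 6, 2
n_blocks = in_features // block_size

w_int = torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8)
scaler = torch.rand(out_features, n_blocks)       # one scale per (output channel, block)
zero_point = torch.rand(out_features, n_blocks)   # one zero point per (output channel, block)

# Layouts expected by the blockwise path of quantized_matmul:
w_blk = w_int.reshape(out_features, n_blocks, block_size).permute(1, 2, 0)  # [3, 2, 8]
scaler_blk = scaler.permute(1, 0)                                           # [3, 8]
zero_point_blk = zero_point.permute(1, 0)                                   # [3, 8]
```
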
75 changes: 60 additions & 15 deletions torch_xla/experimental/xla_quantized_matmul.py
@@ -5,12 +5,12 @@
from torch_xla.core.xla_model import XLA_LIB

XLA_LIB.define(
"quantized_matmul(Tensor x, Tensor w, Tensor scale, int? block_size=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
"quantized_matmul(Tensor x, Tensor w, Tensor scale, Tensor? zero_point=None, int? block_size=-1, bool? int4_weight=False, bool? quantize_activation=False) -> Tensor"
)


def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
w_scaler):
w_scaler, zero_point):
assert w.dtype == torch.int8, f"Weight dtype is expected to be torch.int8, got {w.dtype}."
assert w.dim(
) == 2, f"Weight tensor is expected to be 2D, got {w.dim()}D Tensor."
@@ -19,10 +19,14 @@ def _check_per_channel_quant_weight_dtype_shapes(input_dim, output_dim, w,
1], f"Weight shape is expected to be [output_dim, input_dim], output_dim: {output_dim}, input_dim: {input_dim}, but got {w_shape}."
assert w_scaler.dim() == 1 and w_scaler.shape[0] == w_shape[
0], f"weight scaler shape is expect to be [out_channel,], got {w_scaler.shape}, weight shape {w_shape}."
if zero_point is not None:
assert zero_point.dim() == 1 and w_scaler.shape[0] == w_shape[
0], f"zero point shape is expect to be [out_channel,], got {zero_point.shape}, weight shape {w_shape}."


def _check_blockwise_quant_weight_dtype_shapes(input_dim, output_dim,
block_size, w, w_scaler):
block_size, w, w_scaler,
zero_point):
assert w.dtype == torch.int8, (
f"Weight dtype is expected to be torch.int8, got {w.dtype}.")
assert w.dim() == 3, (
@@ -40,12 +40,20 @@ def _check_blockwise_quant_weight_dtype_shapes(input_dim, output_dim,
assert w_scaler.shape[0] == w_shape[0] and w_scaler.shape[1] == w_shape[-1], (
f"weight scaler shape is expect to be [in_channel / block_size, out_channel], "
f"got {w_scaler.shape}, weight shape {w_shape}.")
if zero_point is not None:
assert zero_point.dim() == 2, (
f"zero_point is expected to be 2D, got {zero_point.dim()}D Tensor.")
assert zero_point.shape[0] == w_shape[0] and zero_point.shape[1] == w_shape[
-1], (
f"zero_point shape is expect to be [in_channel / block_size, out_channel], "
f"got {zero_point.shape}, weight shape {w_shape}.")


@impl(XLA_LIB, "quantized_matmul", "XLA")
def quantized_matmul_xla(x: torch.Tensor,
w: torch.Tensor,
scaler: torch.Tensor,
zero_point: torch.Tensor = None,
block_size: int = -1,
int4_weight: bool = False):
"""Quantized Matrix Multiply op on XLA devices.
@@ -58,6 +70,9 @@ def quantized_matmul_xla(x: torch.Tensor,
scaler: torch.Tensor - Weight scaler.
per-channel quant: [out_channel,].
blockwise quant: [in_channel / block_size, out_channel].
zero_point: Optional[torch.Tensor] - Zero point tensor.
per-channel quant: [out_channel,].
blockwise quant: [in_channel / block_size, out_channel].
block_size: The blocksize for blockwise quantization, -1 for per-channel quantization.
int4_weight: if the weights are int4, the int4 weights need to be stored in a int8
container (unpacked).
@@ -68,58 +83,86 @@
if block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
w, scaler)
return F.linear(x, w) * scaler
w, scaler, zero_point)
out = F.linear(x, w) * scaler
else:
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
block_size, w, scaler)
block_size, w, scaler,
zero_point)
x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
out = torch.einsum('scn,...sc->...sn', w, x)
out = torch.einsum('sn,...sn->...n', scaler, out)
return out
if zero_point is not None:
if block_size == -1:
# Per-channel quant.
zp_out = torch.einsum("...c,z->...z", x, zero_point)
else:
# Blockwise quant.
zp_out = x.sum(dim=-1)
zp_out = torch.matmul(zp_out, zero_point)
out -= zp_out
return out


@impl(XLA_LIB, "quantized_matmul", "CompositeExplicitAutograd")
def quantized_matmul(x: torch.Tensor,
w: torch.Tensor,
scaler: torch.Tensor,
zero_point: torch.Tensor = None,
block_size: int = -1,
int4_weight: bool = False):
if block_size == -1:
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(x.shape[-1], scaler.shape[0],
w, scaler)
w, scaler, zero_point)
w = w.to(x.dtype)
return torch.mul(F.linear(x, w), scaler)
out = torch.mul(F.linear(x, w), scaler)
else:
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(x.shape[-1], w.shape[-1],
block_size, w, scaler)
block_size, w, scaler,
zero_point)
x = x.reshape(*x.shape[:-1], x.shape[-1] // block_size, block_size)
w = w.to(x.dtype)
out = torch.einsum('scn,...sc->...sn', w, x)
out = torch.einsum('sn,...sn->...n', scaler, out)
return out

if zero_point is not None:
if block_size == -1:
# Per-channel quant.
zp_out = torch.einsum("...c,z->...z", x, zero_point)
else:
# Blockwise quant.
zp_out = x.sum(dim=-1)
zp_out = torch.matmul(zp_out, zero_point)
out -= zp_out
return out


class XlaQuantizedLinear(torch.nn.Module):

def __init__(self,
input_dim,
output_dim,
block_size=-1,
is_symmetric: bool = False,
block_size: int = -1,
int4_weight: bool = False):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.is_symmetric = is_symmetric
self.block_size = block_size
self.int4_weight = int4_weight
self.register_buffer('weight',
torch.zeros(output_dim, input_dim).to(torch.int8))
self.register_buffer('weight_scaler', torch.zeros(output_dim))
if not self.is_symmetric:
self.register_buffer('zero_point', torch.zeros(output_dim))
else:
self.zero_point = None

def load_quantized_weight(self, weight, weight_scaler):
def load_quantized_weight(self, weight, weight_scaler, zero_point=None):
'''
weight (Tensor):
per-channel quant: [out_channel, in_channel].
@@ -133,20 +176,22 @@ def load_quantized_weight(self, weight, weight_scaler):
# Per-channel quant.
_check_per_channel_quant_weight_dtype_shapes(self.input_dim,
self.output_dim, weight,
weight_scaler)
weight_scaler, zero_point)
else:
# Blockwise quant.
_check_blockwise_quant_weight_dtype_shapes(self.input_dim,
self.output_dim,
self.block_size, weight,
weight_scaler)
weight_scaler, zero_point)
self.weight = weight
self.weight_scaler = weight_scaler
self.zero_point = zero_point

def forward(self, x):
return torch.ops.xla.quantized_matmul(
x,
self.weight,
self.weight_scaler,
self.zero_point,
block_size=self.block_size,
int4_weight=self.int4_weight)
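
For reference, the zero-point correction applied after the matmul relies on a simple identity: because the caller pre-multiplies the integer zero point by the scale (as `weight_quantization_rtn` does in the tests), subtracting `sum_c(x) * zero_point` from the scaled int matmul reproduces a matmul against the dequantized weight. A minimal CPU sketch for the per-channel case, with invented tensors, checking that identity:

```python
import torch
import torch.nn.functional as F

out_features, in_features = 8, 6
x = torch.randn(3, in_features)
w_int = torch.randint(-128, 128, (out_features, in_features)).float()
scale = torch.rand(out_features)                         # per-output-channel scale
zp_int = torch.randint(-10, 10, (out_features,)).float() # integer zero point

# Reference: dequantize the weight, then matmul.
w_fp = (w_int - zp_int[:, None]) * scale[:, None]
reference = F.linear(x, w_fp)

# What quantized_matmul computes: scaled int matmul minus the zero-point term.
zero_point = zp_int * scale
corrected = F.linear(x, w_int) * scale - torch.einsum("bc,z->bz", x, zero_point)

assert torch.allclose(reference, corrected, atol=1e-3)
```
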