Commit f3d9e58

be consistent with origin vllm
Signed-off-by: kewang-xlnx <kewang@xilinx.com>
1 parent 2c61465 commit f3d9e58

File tree

3 files changed: +36, -112 lines changed

tests/quantization/test_quark.py

Lines changed: 1 addition & 19 deletions
@@ -6,7 +6,7 @@
 import torch
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
-    QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
+    QuarkLinearMethod, QuarkW8A8Fp8)
 
 
 def test_quark_fp8(vllm_runner):
@@ -28,21 +28,3 @@ def test_quark_fp8(vllm_runner):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
-
-
-def test_quark_int8(vllm_runner):
-    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
-    with vllm_runner(model_path) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-
-        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
-        assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)
-
-        if isinstance(qkv_proj.scheme, QuarkW8A8Int8):
-            assert qkv_proj.weight.dtype is torch.int8
-
-        output = llm.generate_greedy("Hello my name is", max_tokens=20)
-        assert output
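
For context, the surviving test_quark_fp8 uses the same layer-introspection pattern as the removed int8 test. A minimal sketch of that pattern, assuming the vllm_runner fixture seen above; the test name and checkpoint path below are placeholders, not the ones used upstream:

# Sketch only: mirrors the structure of the deleted test_quark_int8, but
# checks the fp8 scheme that remains covered. The model path is hypothetical.
def test_quark_fp8_sketch(vllm_runner):
    model_path = "org/some-quark-fp8-checkpoint"  # placeholder checkpoint
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        qkv_proj = model.model.layers[0].self_attn.qkv_proj

        # The fused QKV projection should be handled by Quark's linear method
        # with the fp8 scheme attached.
        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output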

vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 4 additions & 2 deletions
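
The hunks below swap the removed vllm.utils.print_warning_once helper for a module-level logger created with init_logger; logger.warning_once deduplicates the repeated message. A minimal sketch of the pattern adopted here, with a hypothetical helper name used purely for illustration:

# Sketch of the module-level logger pattern used by quark_moe.py.
from vllm.logger import init_logger

logger = init_logger(__name__)


def warn_if_scales_differ(scales_equal: bool) -> None:
    # Hypothetical helper, shown only to illustrate warning_once usage.
    if not scales_equal:
        logger.warning_once(
            "Found input_scales that are not equal for "
            "fp8 MoE layer. Using the maximum across experts "
            "for each layer. ")
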
@@ -4,13 +4,15 @@
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+
+logger = init_logger(__name__)
 
 __all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]
 
@@ -127,7 +129,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     "activation scales are None.")
             if (not all_close_1d(layer.w13_input_scale)
                     or not all_close_1d(layer.w2_input_scale)):
-                print_warning_once(
+                logger.warning_once(
                     "Found input_scales that are not equal for "
                     "fp8 MoE layer. Using the maximum across experts "
                     "for each layer. ")
Lines changed: 31 additions & 91 deletions
@@ -1,12 +1,11 @@
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Set
 
 import torch
-from torch.nn import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_int8_linear, convert_to_channelwise)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
@@ -16,6 +15,7 @@
 
 
 class QuarkW8A8Int8(QuarkScheme):
+    _kernel_backends_being_used: Set[str] = set()
 
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
                  input_symmetric: Optional[bool]):
@@ -28,77 +28,25 @@ def get_min_capability(cls) -> int:
         # turing and up
         return 75
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # WEIGHT
-        # Cutlass kernels need transposed weight.
-        weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-
-        # WEIGHT SCALE
-        # Cutlass kernels support only per-tensor and per-channel.
-        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
-        # scales being passed to the kernel), convert to the per-channel case.
-        is_fused_module = len(self.logical_widths) > 1
-        if is_fused_module and self.qscheme == "per_tensor":
-            ws_channelwise = convert_to_channelwise(layer.weight_scale,
-                                                    self.logical_widths)
-            layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
-        else:
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
-        layer.weight_zero_point = None
-
-        # INPUT SCALE
-        if self.is_static_input_scheme:
-            if self.input_symmetric:
-                layer.input_scale = Parameter(layer.input_scale.max(),
-                                              requires_grad=False)
-                layer.input_zero_point = None
-            else:
-                # reconstruct the ranges
-                int8_traits = torch.iinfo(torch.int8)
-                azps = layer.input_zero_point.to(dtype=torch.int32)
-                range_max = (layer.input_scale *
-                             (int8_traits.max - azps)).max()
-                range_min = (layer.input_scale *
-                             (int8_traits.min - azps)).min()
-
-                scale = (range_max - range_min) / (int8_traits.max -
-                                                   int8_traits.min)
-                layer.input_scale = Parameter(scale, requires_grad=False)
-
-                # AZP loaded as int8 but used as int32
-                azp = (int8_traits.min -
-                       range_min / scale).to(dtype=torch.int32)
-                layer.input_zero_point = Parameter(azp, requires_grad=False)
-
-        else:
-            layer.input_scale = None
-            layer.input_zero_point = None
-
-        # azp_adj is the AZP adjustment term, used to account for weights.
-        # It does not depend on scales or azp, so it is the same for
-        # static and dynamic quantization.
-        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
-        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
-        if not self.input_symmetric:
-            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
-            if self.is_static_input_scheme:
-                # cutlass_w8a8 requires azp to be folded into azp_adj
-                # in the per-tensor case
-                azp_adj = layer.input_zero_point * azp_adj
-
-            layer.azp_adj = azp_adj
-        else:
-            layer.azp_adj = None
-
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
         self.logical_widths = output_partition_sizes
 
+        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+            is_channelwise=(self.qscheme == "per_channel"),
+            is_static_input_scheme=(self.is_static_input_scheme is True),
+            input_symmetric=(self.input_symmetric is True))
+
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config)
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
         # WEIGHT
         weight = ModelWeightParameter(data=torch.empty(
             sum(output_partition_sizes),
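
The asymmetric branch deleted above collapsed per-partition (scale, zero_point) pairs into a single per-tensor pair by reconstructing the real-valued range they describe; that logic now lives inside the chosen kernel. A standalone numeric sketch of the same arithmetic, with illustrative values that do not come from any checkpoint:

import torch

# Recover the real range implied by each (scale, azp) pair, then derive one
# scale and zero point covering the union of those ranges; this is the same
# math the removed process_weights_after_loading performed.
int8_traits = torch.iinfo(torch.int8)            # qmin = -128, qmax = 127
input_scale = torch.tensor([0.020, 0.025])       # example per-partition scales
input_zero_point = torch.tensor([10, -5], dtype=torch.int8)

azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()

scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
print(scale.item(), azp.item())
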
@@ -117,22 +65,12 @@ def create_weights(self, layer: torch.nn.Module,
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
-            weight_zero_point = ChannelQuantScaleParameter(
-                data=torch.zeros((sum(output_partition_sizes), 1),
-                                 dtype=torch.int8),
-                output_dim=0,
-                weight_loader=weight_loader)
         else:
             assert self.qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
-            weight_zero_point = PerTensorScaleParameter(
-                data=torch.zeros(len(output_partition_sizes),
-                                 dtype=torch.int8),
-                weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
-        layer.register_parameter("weight_zero_point", weight_zero_point)
 
         # INPUT SCALE
         if self.is_static_input_scheme:
@@ -142,24 +80,26 @@ def create_weights(self, layer: torch.nn.Module,
             layer.register_parameter("input_scale", input_scale)
 
             if not self.input_symmetric:
-                # Note: compressed-tensors stores the zp using the same dtype
+                # Note: quark stores the zp using the same dtype
                 # as the weights
                 # AZP loaded as int8 but used as int32
                 input_zero_point = BasevLLMParameter(
                     data=torch.empty(1, dtype=torch.int8),
                     weight_loader=weight_loader)
-            else:
-                input_zero_point = BasevLLMParameter(
-                    data=torch.zeros(1, dtype=torch.int8),
-                    weight_loader=weight_loader)
-            layer.register_parameter("input_zero_point", input_zero_point)
+                layer.register_parameter("input_zero_point", input_zero_point)
+
+        self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
+                                  w_q_param_name="weight",
+                                  w_s_param_name="weight_scale",
+                                  i_s_param_name="input_scale",
+                                  i_zp_param_name="input_zero_point",
+                                  azp_adj_param_name="azp_adj")
+
+    # Checkpoints are serialized in quark format, which is
+    # different from the format the kernel may want. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return apply_int8_linear(input=x,
-                                 weight=layer.weight,
-                                 weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale,
-                                 input_zero_point=layer.input_zero_point,
-                                 azp_adj=layer.azp_adj,
-                                 bias=bias)
+        return self.kernel.apply_weights(layer, x, bias)
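
Taken together, QuarkW8A8Int8 now defers the int8 GEMM details (weight repacking, azp_adj computation, the actual scaled-mm call) to a ScaledMMLinearKernel chosen from the layer configuration. A condensed sketch of that flow, assuming only the constructor arguments shown in the diff; the layer object and its loaded parameters are omitted:

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)

# Describe the scheme once; the helper picks a kernel class suited to the
# current platform (for example, a CUTLASS-backed kernel on supported GPUs).
config = ScaledMMLinearLayerConfig(
    is_channelwise=False,          # per-tensor weight scales in this example
    is_static_input_scheme=True,   # a static activation scale was loaded
    input_symmetric=True)          # no activation zero point

kernel_cls = choose_scaled_mm_linear_kernel(config)
kernel = kernel_cls(c=config,
                    w_q_param_name="weight",
                    w_s_param_name="weight_scale",
                    i_s_param_name="input_scale",
                    i_zp_param_name="input_zero_point",
                    azp_adj_param_name="azp_adj")

# After the checkpoint is loaded, repacking happens once per layer:
#     kernel.process_weights_after_loading(layer)
# and every forward pass reduces to a single call:
#     out = kernel.apply_weights(layer, x, bias)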
