@@ -1,21 +1,21 @@
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Set
 
 import torch
-from torch.nn import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_int8_linear, convert_to_channelwise)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
 
 logger = init_logger(__name__)
 
 
 class QuarkW8A8Int8(QuarkScheme):
+    _kernel_backends_being_used: Set[str] = set()
 
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
                  input_symmetric: Optional[bool]):
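The new `_kernel_backends_being_used` attribute lives on the class, not the instance, so the "Using ... for QuarkW8A8Int8" log line below fires once per kernel backend rather than once per layer. A minimal sketch of that dedup pattern (class and method names here are illustrative, not vLLM's):

```python
from typing import Set


class Scheme:
    _backends_seen: Set[str] = set()  # shared across all instances

    def _note_backend(self, name: str) -> None:
        if name not in self._backends_seen:
            print(f"Using {name}")  # logged once per backend, not per layer
            self._backends_seen.add(name)
```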
@@ -28,77 +28,25 @@ def get_min_capability(cls) -> int:
         # turing and up
         return 75
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # WEIGHT
-        # Cutlass kernels need transposed weight.
-        weight = layer.weight
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-
-        # WEIGHT SCALE
-        # Cutlass kernels support only per-tensor and per-channel.
-        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
-        # scales being passed to the kernel), convert to the per-channel case.
-        is_fused_module = len(self.logical_widths) > 1
-        if is_fused_module and self.qscheme == "per_tensor":
-            ws_channelwise = convert_to_channelwise(layer.weight_scale,
-                                                    self.logical_widths)
-            layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
-        else:
-            layer.weight_scale = Parameter(layer.weight_scale.data,
-                                           requires_grad=False)
-        layer.weight_zero_point = None
-
-        # INPUT SCALE
-        if self.is_static_input_scheme:
-            if self.input_symmetric:
-                layer.input_scale = Parameter(layer.input_scale.max(),
-                                              requires_grad=False)
-                layer.input_zero_point = None
-            else:
-                # reconstruct the ranges
-                int8_traits = torch.iinfo(torch.int8)
-                azps = layer.input_zero_point.to(dtype=torch.int32)
-                range_max = (layer.input_scale *
-                             (int8_traits.max - azps)).max()
-                range_min = (layer.input_scale *
-                             (int8_traits.min - azps)).min()
-
-                scale = (range_max - range_min) / (int8_traits.max -
-                                                   int8_traits.min)
-                layer.input_scale = Parameter(scale, requires_grad=False)
-
-                # AZP loaded as int8 but used as int32
-                azp = (int8_traits.min -
-                       range_min / scale).to(dtype=torch.int32)
-                layer.input_zero_point = Parameter(azp, requires_grad=False)
-
-        else:
-            layer.input_scale = None
-            layer.input_zero_point = None
-
-        # azp_adj is the AZP adjustment term, used to account for weights.
-        # It does not depend on scales or azp, so it is the same for
-        # static and dynamic quantization.
-        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
-        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
-        if not self.input_symmetric:
-            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
-            if self.is_static_input_scheme:
-                # cutlass_w8a8 requires azp to be folded into azp_adj
-                # in the per-tensor case
-                azp_adj = layer.input_zero_point * azp_adj
-
-            layer.azp_adj = azp_adj
-        else:
-            layer.azp_adj = None
-
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
         self.logical_widths = output_partition_sizes
 
+        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
+            is_channelwise=(self.qscheme == "per_channel"),
+            is_static_input_scheme=(self.is_static_input_scheme is True),
+            input_symmetric=(self.input_symmetric is True))
+
+        kernel_type = choose_scaled_mm_linear_kernel(
+            scaled_mm_linear_kernel_config)
+
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
         # WEIGHT
         weight = ModelWeightParameter(data=torch.empty(
             sum(output_partition_sizes),
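For reference, the trickiest arithmetic deleted above is the static asymmetric case: fused modules load one (scale, zero-point) pair per logical partition, and the old code collapsed them into a single per-tensor pair by reconstructing the float ranges they encode. A standalone worked sketch of that math (tensor values are illustrative; the operations mirror the removed lines):

```python
import torch

int8_traits = torch.iinfo(torch.int8)

# Per-partition scales/zero-points as a fused QKV module might load them.
input_scale = torch.tensor([0.020, 0.015, 0.025])
input_zero_point = torch.tensor([3, -5, 7], dtype=torch.int8)

# Recover the float range each (scale, azp) pair encodes ...
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()

# ... then collapse the union of ranges into one per-tensor scale/azp.
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
```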
@@ -117,22 +65,12 @@ def create_weights(self, layer: torch.nn.Module,
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
-            weight_zero_point = ChannelQuantScaleParameter(
-                data=torch.zeros((sum(output_partition_sizes), 1),
-                                 dtype=torch.int8),
-                output_dim=0,
-                weight_loader=weight_loader)
         else:
             assert self.qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
-            weight_zero_point = PerTensorScaleParameter(
-                data=torch.zeros(len(output_partition_sizes),
-                                 dtype=torch.int8),
-                weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
-        layer.register_parameter("weight_zero_point", weight_zero_point)
 
         # INPUT SCALE
         if self.is_static_input_scheme:
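Note the shapes the two branches keep: per-channel uses one scale per output channel, `(sum(output_partition_sizes), 1)`, while per-tensor keeps one scale per logical partition, `(len(output_partition_sizes),)`. The removed code expanded the latter to the former for fused modules via `convert_to_channelwise`; a rough equivalent of that expansion (this is a sketch, not vLLM's actual helper) looks like:

```python
from typing import List

import torch


def to_channelwise(per_tensor_scales: torch.Tensor,
                   logical_widths: List[int]) -> torch.Tensor:
    # Repeat each partition's scalar scale across its output channels.
    return torch.cat([
        scale.repeat(width)
        for scale, width in zip(per_tensor_scales, logical_widths)
    ]).unsqueeze(-1)


# e.g. a fused QKV projection with widths 8/4/4 and three scales:
scales = to_channelwise(torch.tensor([0.1, 0.2, 0.3]), [8, 4, 4])
assert scales.shape == (16, 1)  # one scale per output channel
```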
@@ -142,24 +80,26 @@ def create_weights(self, layer: torch.nn.Module,
             layer.register_parameter("input_scale", input_scale)
 
             if not self.input_symmetric:
-                # Note: compressed-tensors stores the zp using the same dtype
+                # Note: quark stores the zp using the same dtype
                 # as the weights
                 # AZP loaded as int8 but used as int32
                 input_zero_point = BasevLLMParameter(
                     data=torch.empty(1, dtype=torch.int8),
                     weight_loader=weight_loader)
-            else:
-                input_zero_point = BasevLLMParameter(
-                    data=torch.zeros(1, dtype=torch.int8),
-                    weight_loader=weight_loader)
-            layer.register_parameter("input_zero_point", input_zero_point)
+                layer.register_parameter("input_zero_point", input_zero_point)
+
+        self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
+                                  w_q_param_name="weight",
+                                  w_s_param_name="weight_scale",
+                                  i_s_param_name="input_scale",
+                                  i_zp_param_name="input_zero_point",
+                                  azp_adj_param_name="azp_adj")
+
+    # Checkpoints are serialized in quark format, which is
+    # different from the format the kernel may want. Handle repacking here.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return apply_int8_linear(input=x,
-                                 weight=layer.weight,
-                                 weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale,
-                                 input_zero_point=layer.input_zero_point,
-                                 azp_adj=layer.azp_adj,
-                                 bias=bias)
+        return self.kernel.apply_weights(layer, x, bias)
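Putting the pieces together, the kernel-delegation lifecycle this diff wires up can be sketched end to end. In this condensed sketch, `layer`, `x`, and `bias` are assumed: a module already holding the quark-loaded parameters ("weight", "weight_scale", "input_scale", "input_zero_point") and the activation/bias tensors for one forward call.

```python
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)

# Describe the scheme, then let vLLM pick the best available backend.
config = ScaledMMLinearLayerConfig(is_channelwise=False,
                                   is_static_input_scheme=True,
                                   input_symmetric=False)
kernel_cls = choose_scaled_mm_linear_kernel(config)

# Tell the kernel which layer attributes hold each quantization tensor.
kernel = kernel_cls(c=config,
                    w_q_param_name="weight",
                    w_s_param_name="weight_scale",
                    i_s_param_name="input_scale",
                    i_zp_param_name="input_zero_point",
                    azp_adj_param_name="azp_adj")

kernel.process_weights_after_loading(layer)  # repack into kernel layout
out = kernel.apply_weights(layer, x, bias)   # quantized matmul + dequant
```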