 
 import itertools
 from abc import abstractmethod
-from typing import Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+# TODO(Isotr0py): We might need a more flexible structure to handle
+# bitsandbytes shard offsets.
+def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
+    """
+    Separate the BitsAndBytes 4-bit shard.
+
+    For example, given bnb weight attributes as below:
+    {
+        'bnb_shard_offsets': array([0, 4, 8, 16]),
+        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
+    }
+
+    The function will return:
+    {
+        'bnb_shard_offsets': array([0, 4]),
+        'bnb_quant_state': {0: ...},
+    }
+    and
+    {
+        'bnb_shard_offsets': array([0, 4, 12]),
+        'bnb_quant_state': {0: ..., 1: ...},
+    }
+    """
+    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
+    offset_l = shard_offsets[:2]
+    offset_r = shard_offsets[1:] - shard_offsets[1]
+    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
+    quant_state_r = {
+        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
+        for i in range(1,
+                       len(shard_offsets) - 1)
+    }
+    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
+    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
+    return left, right
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
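
For reference, here is a minimal standalone sketch of the split performed by left_shift_bitsandbytes_4bit_shard above, using a toy attribute dict: numpy offsets and string stand-ins for the real bitsandbytes quant-state objects. The values are illustrative only.

import numpy as np

# Toy bnb weight attrs: three fused 4-bit shards packed back to back.
bnb_weight_attrs = {
    "bnb_shard_offsets": np.array([0, 4, 8, 16]),
    "bnb_quant_state": {0: "q_state", 1: "k_state", 2: "v_state"},
}

shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
# Left part: only the first shard.
left = {
    "bnb_shard_offsets": shard_offsets[:2],
    "bnb_quant_state": {0: bnb_weight_attrs["bnb_quant_state"][0]},
}
# Right part: the remaining shards, rebased so their offsets start at zero.
right = {
    "bnb_shard_offsets": shard_offsets[1:] - shard_offsets[1],
    "bnb_quant_state": {
        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
        for i in range(1, len(shard_offsets) - 1)
    },
}
# Offsets split into [0, 4] and [0, 4, 12].
print(left["bnb_shard_offsets"], right["bnb_shard_offsets"])
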
@@ -1229,7 +1267,24 @@ def extra_repr(self) -> str:
         return s
 
 
-class QKVCrossParallelLinear(torch.nn.Module):
+class QKVCrossParallelLinear(LinearBase):
+    """Linear layers for efficient cross-attention's QKV transformation.
+
+    Args:
+        hidden_size: input hidden state size of the transformer.
+        head_size: size of each attention head.
+        total_num_heads: total number of attention query heads.
+        total_num_kv_heads: total number of attention key/value heads. If
+                            None, assume total_num_kv_heads = total_num_heads.
+        bias: If true, add bias.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. We
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization config.
+        prefix: The name of the layer in the state dict, including all parents
+                (e.g. model.layers.0.qkv_proj).
+    """
 
     def __init__(self,
                  hidden_size: int,
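
A quick worked example of the size bookkeeping described in the class docstring above. All numbers are invented for illustration, a single tensor-parallel rank is assumed, and the per-rank kv_size used by forward() further below is assumed to be total_num_kv_heads * head_size, which is set outside the hunks shown here.

# Illustrative sizes only; not taken from this diff.
hidden_size, head_size = 4096, 128
total_num_heads, total_num_kv_heads = 32, 8

# Decoder-side query projection width (q in forward()).
q_size = total_num_heads * head_size                                # 4096
# Assumed per-rank width of each of K and V from the encoder projection.
kv_size = total_num_kv_heads * head_size                            # 1024
# Placeholder output_size passed to LinearBase.__init__ in the next hunk,
# "not used, just for alignment".
output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size  # 5120
print(q_size, kv_size, output_size)
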
@@ -1241,12 +1296,28 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
-        super().__init__()
+        # input_size and output_size are not used, just for alignment
+        input_size = hidden_size
+        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
+        super().__init__(input_size=input_size,
+                         output_size=output_size,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+        self.quant_config = quant_config
+
         # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
+        placeholder_size = 0
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         placeholder_size, [placeholder_size],
+                                         placeholder_size,
+                                         placeholder_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+
         # Use a dictionary to avoid submodules parameters auto-registration:
         # drop-in replacement for a `QKVParallelLinear` module.
         self.proj = dict()
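
The self.proj = dict() line above keeps the real Q/KV sub-layers out of PyTorch's module registry, so only the placeholder weight (and bias) created through quant_method.create_weights appear as this wrapper's own parameters. A minimal, vLLM-free sketch of that registration behavior; the Wrapper class and sizes are made up for illustration:

import torch
import torch.nn as nn


class Wrapper(nn.Module):

    def __init__(self):
        super().__init__()
        # Placeholder parameter: the only entry this module registers itself.
        self.weight = nn.Parameter(torch.empty(0))
        # A plain dict is not auto-registered by nn.Module, so the inner
        # layer's parameters stay hidden from named_parameters() and the
        # state dict of the wrapper.
        self.proj = dict(inner=nn.Linear(8, 8))

    def forward(self, x):
        return self.proj["inner"](x)


m = Wrapper()
print([name for name, _ in m.named_parameters()])  # only ['weight']
print(m(torch.randn(2, 8)).shape)                  # torch.Size([2, 8])
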
@@ -1276,18 +1347,94 @@ def __init__(self,
         if bias:
             self.bias = torch.nn.Parameter()
             set_weight_attrs(self.bias, {
-                "weight_loader": self.weight_loader_bias,
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
             })
+        else:
+            self.bias = None
 
     @property
-    def q_proj_decoder(self):
-        return self.proj["q_proj_decoder"]
+    def q_proj_decoder(self) -> ColumnParallelLinear:
+        layer = self.proj["q_proj_decoder"]
+        for name, param in self.named_parameters():
+            target_param = getattr(layer, name)
+            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+        return layer
 
     @property
-    def kv_proj_encoder(self):
-        return self.proj["kv_proj_encoder"]
+    def kv_proj_encoder(self) -> QKVParallelLinear:
+        layer = self.proj["kv_proj_encoder"]
+        for name, param in self.named_parameters():
+            target_param = getattr(layer, name)
+            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+        return layer
+
+    def sync_weight_attrs(
+        self,
+        src_param: nn.Parameter,
+        tgt_param: nn.Parameter,
+        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
+    ):
+        missing_attrs_dict = {
+            k: getattr(src_param, k)
+            for k in (set(src_param.__dict__.keys()) -
+                      set(tgt_param.__dict__.keys()))
+        }
+        # TODO(Isotr0py): handle bitsandbytes 8bit
+        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
+                                        False)
+        if missing_attrs_dict and use_bitsandbytes_4bit:
+            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
+                missing_attrs_dict)
+            if mode == "q_proj_decoder":
+                set_weight_attrs(tgt_param, q_proj_attrs)
+            elif mode == "kv_proj_encoder":
+                set_weight_attrs(tgt_param, kv_proj_attrs)
+        else:
+            set_weight_attrs(tgt_param, missing_attrs_dict)
 
-    def forward(self, decoder_hidden_states, encoder_hidden_states):
+    def _is_same_param(
+        self,
+        src_param: torch.nn.Parameter,
+        map_param: torch.nn.Parameter,
+    ) -> bool:
+        """Check if two parameters are exactly pointing to same things."""
+        # ignore weight_loader because it's always different
+        key_to_ignore = ["weight_loader", "_weight_loader"]
+        has_same_type_name = type(src_param) is type(map_param)
+        src_param_attrs = {
+            k: v
+            for k, v in src_param.__dict__.items() if k not in key_to_ignore
+        }
+        map_param_attrs = {
+            k: v
+            for k, v in map_param.__dict__.items() if k not in key_to_ignore
+        }
+        has_same_attrs = src_param_attrs == map_param_attrs
+        return has_same_type_name and has_same_attrs
+
+    def select_proj_params(
+        self,
+        layer: nn.Module,
+        param: nn.Parameter,
+    ) -> nn.Parameter:
+        """
+        Given the placeholder param,
+        return the corresponding param in the proj layers.
+        """
+        target_param_list = [
+            v for _, v in layer.named_parameters()
+            if self._is_same_param(param, v)
+        ]
+        assert len(target_param_list) == 1
+        target_param = target_param_list[0]
+        return target_param
+
+    def forward(  # type: ignore[override]
+        self,
+        decoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
         q, _ = self.q_proj_decoder(decoder_hidden_states)
         if encoder_hidden_states is None:
             # Encoder KV already cached.
@@ -1300,25 +1447,21 @@ def forward(self, decoder_hidden_states, encoder_hidden_states):
             k, v = kv_enc.split(self.kv_size, dim=-1)
         return q, k, v
 
-    def weight_loader_weight(self,
-                             param: torch.nn.Parameter,
-                             loaded_weight: torch.Tensor,
-                             loaded_shard_id: Optional[str] = None):
-        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
-        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.weight
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
-
-    def weight_loader_bias(self,
-                           param: torch.nn.Parameter,
-                           loaded_weight: torch.Tensor,
-                           loaded_shard_id: Optional[str] = None):
-        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.bias
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+    def weight_loader(self,
+                      param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[str] = None):
+        layer = (self.q_proj_decoder
+                 if loaded_shard_id == "q" else self.kv_proj_encoder)
+        target_param = self.select_proj_params(layer, param)
+        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
+        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+
+    def extra_repr(self) -> str:
+        s = f"in_features={self.input_size}"
+        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", kv_size={self.kv_size}"
+        s += f", bias={self.bias is not None}"
+        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += ", gather_output=False"
+        return s
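
Finally, a small mock of how the single weight_loader above fans checkpoint shards out to the two sub-layers. The stub classes, names, and shapes are invented for illustration; in the real layer the targets are the ColumnParallelLinear and QKVParallelLinear instances held in self.proj.

import torch


class StubQProjDecoder:

    def weight_loader(self, param, loaded_weight):
        print("q_proj_decoder receives", tuple(loaded_weight.shape))


class StubKVProjEncoder:

    def weight_loader(self, param, loaded_weight, loaded_shard_id):
        print(f"kv_proj_encoder receives shard {loaded_shard_id!r}",
              tuple(loaded_weight.shape))


q_proj_decoder, kv_proj_encoder = StubQProjDecoder(), StubKVProjEncoder()


def route(loaded_weight, loaded_shard_id):
    # Mirrors the dispatch in QKVCrossParallelLinear.weight_loader: "q" goes
    # to the decoder projection without a shard id, while "k"/"v" go to the
    # encoder projection with the shard id forwarded.
    layer = (q_proj_decoder
             if loaded_shard_id == "q" else kv_proj_encoder)
    shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
    layer.weight_loader(None, loaded_weight, *shard_id_args)


route(torch.zeros(4096, 4096), "q")  # -> q_proj_decoder
route(torch.zeros(1024, 4096), "k")  # -> kv_proj_encoder, shard "k"
route(torch.zeros(1024, 4096), "v")  # -> kv_proj_encoder, shard "v"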