
Commit e392d85

[Core] Refactor QKVCrossParallelLinear implementation to support BNB 4-bit quantization (#14545)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent 77a318b commit e392d85

3 files changed: +234, -65 lines changed
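In end-user terms, the refactor lets an mllama-style cross-attention model be loaded with bitsandbytes 4-bit quantization through the regular LLM entrypoint. The sketch below is not part of the commit: it mirrors the regression test added in test_mllama.py further down, and the checkpoint name is a placeholder rather than a value taken from this diff.

# Minimal sketch of the load path this commit fixes; the model id is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(
    model="<mllama-style-checkpoint>",  # placeholder, see `models` in the test below
    dtype="float16",
    max_model_len=4096,
    max_num_seqs=2,
    enforce_eager=True,
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)
outputs = llm.generate(
    ["The color of the sky is blue but sometimes it can also be"],
    SamplingParams(temperature=0, max_tokens=32),
)
print(outputs[0].outputs[0].text)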

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 45 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
                           _ImageAssets)
+from ....quantization.utils import is_quant_method_supported
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
@@ -397,6 +398,50 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
     )
 
 
+@large_gpu_test(min_gb=48)
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+def test_bnb_regression(
+    image_assets: _ImageAssets,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+):
+    stop_sign = image_assets[0].pil_image
+    prompts = [
+        {
+            "prompt": "<|begin_of_text|>The content of the image <|image|> is",
+            "multi_modal_data": {
+                "image": stop_sign
+            },
+        },
+        {
+            "prompt":
+            "The color of the sky is blue but sometimes it can also be",
+        },
+    ]
+    # Test regression about QKVCrossParallelLinear
+    llm = LLM(
+        model=model,
+        dtype=dtype,
+        max_model_len=4096,
+        max_num_seqs=2,
+        enforce_eager=True,
+        quantization="bitsandbytes",
+        load_format="bitsandbytes",
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=max_tokens,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    assert outputs
+
+
 @large_gpu_test(min_gb=48)
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
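Not part of the commit: one way to invoke only the new regression test locally, assuming a bitsandbytes-capable GPU with at least 48 GB of memory (per the markers above).

# Run just test_bnb_regression from this file through pytest's Python entrypoint.
import pytest

raise SystemExit(
    pytest.main([
        "tests/models/encoder_decoder/vision_language/test_mllama.py",
        "-k", "test_bnb_regression",
        "-v",
    ]))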

vllm/model_executor/layers/linear.py

Lines changed: 178 additions & 35 deletions
@@ -2,9 +2,10 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+# TODO(Isotr0py): We might need a more flexible structure to handle
+# bitsandbytes shard offsets.
+def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
+    """
+    Separate the BitsAndBytes 4-bit shard.
+
+    For example, given bnb weight attributes as below:
+    {
+        'bnb_shard_offsets': array([0, 4, 8, 16]),
+        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
+    }
+
+    The function will return:
+    {
+        'bnb_shard_offsets': array([0, 4]),
+        'bnb_quant_state': {0: ...},
+    }
+    and
+    {
+        'bnb_shard_offsets': array([0, 4, 12]),
+        'bnb_quant_state': {0: ..., 1: ...},
+    }
+    """
+    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
+    offset_l = shard_offsets[:2]
+    offset_r = shard_offsets[1:] - shard_offsets[1]
+    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
+    quant_state_r = {
+        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
+        for i in range(1,
+                       len(shard_offsets) - 1)
+    }
+    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
+    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
+    return left, right
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
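Not part of the diff: a minimal usage sketch of the helper added above, fed the offsets from its docstring. The placeholder strings stand in for real bitsandbytes quant-state objects, and the import path simply follows the file this hunk modifies.

# Split fused q/k/v bitsandbytes 4-bit shard metadata into the part consumed by
# q_proj_decoder and the part consumed by kv_proj_encoder.
import numpy as np

from vllm.model_executor.layers.linear import left_shift_bitsandbytes_4bit_shard

fused_attrs = {
    "bnb_shard_offsets": np.array([0, 4, 8, 16]),  # q, k, v shard boundaries
    "bnb_quant_state": {0: "q_state", 1: "k_state", 2: "v_state"},  # placeholders
}
q_attrs, kv_attrs = left_shift_bitsandbytes_4bit_shard(fused_attrs)
print(q_attrs)   # {'bnb_shard_offsets': array([0, 4]), 'bnb_quant_state': {0: 'q_state'}}
print(kv_attrs)  # {'bnb_shard_offsets': array([0, 4, 12]),
                 #  'bnb_quant_state': {0: 'k_state', 1: 'v_state'}}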

@@ -1229,7 +1267,24 @@ def extra_repr(self) -> str:
         return s
 
 
-class QKVCrossParallelLinear(torch.nn.Module):
+class QKVCrossParallelLinear(LinearBase):
+    """Linear layers for efficient cross-attention's QKV transformation.
+
+    Args:
+        hidden_size: input hidden state size of the transformer.
+        head_size: size of each attention head.
+        total_num_heads: total number of attention query heads.
+        total_num_kv_heads: total number of attention key/value heads. If
+                            None, assume total_num_kv_heads = total_num_heads.
+        bias: If true, add bias.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. we
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+                (e.g. model.layers.0.qkv_proj)
+    """
 
     def __init__(self,
                  hidden_size: int,
@@ -1241,12 +1296,28 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
-        super().__init__()
+        # input_size and output_size are not used, just for alignment
+        input_size = hidden_size
+        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
+        super().__init__(input_size=input_size,
+                         output_size=output_size,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+        self.quant_config = quant_config
+
         # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
+        placeholder_size = 0
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         placeholder_size, [placeholder_size],
+                                         placeholder_size,
+                                         placeholder_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+
         # Use a dictionary to avoid submodules parameters auto-registration:
         # drop-in replacement for a `QKVParallelLinear` module.
         self.proj = dict()
@@ -1276,18 +1347,94 @@ def __init__(self,
         if bias:
             self.bias = torch.nn.Parameter()
             set_weight_attrs(self.bias, {
-                "weight_loader": self.weight_loader_bias,
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
             })
+        else:
+            self.bias = None
 
     @property
-    def q_proj_decoder(self):
-        return self.proj["q_proj_decoder"]
+    def q_proj_decoder(self) -> ColumnParallelLinear:
+        layer = self.proj["q_proj_decoder"]
+        for name, param in self.named_parameters():
+            target_param = getattr(layer, name)
+            self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
+        return layer
 
     @property
-    def kv_proj_encoder(self):
-        return self.proj["kv_proj_encoder"]
+    def kv_proj_encoder(self) -> QKVParallelLinear:
+        layer = self.proj["kv_proj_encoder"]
+        for name, param in self.named_parameters():
+            target_param = getattr(layer, name)
+            self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
+        return layer
+
+    def sync_weight_attrs(
+        self,
+        src_param: nn.Parameter,
+        tgt_param: nn.Parameter,
+        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
+    ):
+        missing_attrs_dict = {
+            k: getattr(src_param, k)
+            for k in (set(src_param.__dict__.keys()) -
+                      set(tgt_param.__dict__.keys()))
+        }
+        # TODO(Isotr0py): handle bitsandbytes 8bit
+        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
+                                        False)
+        if (missing_attrs_dict and use_bitsandbytes_4bit):
+            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
+                missing_attrs_dict)
+            if mode == "q_proj_decoder":
+                set_weight_attrs(tgt_param, q_proj_attrs)
+            elif mode == "kv_proj_encoder":
+                set_weight_attrs(tgt_param, kv_proj_attrs)
+        else:
+            set_weight_attrs(tgt_param, missing_attrs_dict)
 
-    def forward(self, decoder_hidden_states, encoder_hidden_states):
+    def _is_same_param(
+        self,
+        src_param: torch.nn.Parameter,
+        map_param: torch.nn.Parameter,
+    ) -> bool:
+        """Check if two parameters are exactly pointing to same things."""
+        # ignore weight_loader because it's always different
+        key_to_ignore = ["weight_loader", "_weight_loader"]
+        has_same_type_name = type(src_param) is type(map_param)
+        src_param_attrs = {
+            k: v
+            for k, v in src_param.__dict__.items() if k not in key_to_ignore
+        }
+        map_param_attrs = {
+            k: v
+            for k, v in map_param.__dict__.items() if k not in key_to_ignore
+        }
+        has_same_attrs = src_param_attrs == map_param_attrs
+        return has_same_type_name and has_same_attrs
+
+    def select_proj_params(
+        self,
+        layer: nn.Module,
+        param: nn.Parameter,
+    ) -> nn.Parameter:
+        """
+        Given the placeholder param,
+        return the corresponding param in the proj layers.
+        """
+        target_param_list = [
+            v for _, v in layer.named_parameters()
+            if self._is_same_param(param, v)
+        ]
+        assert len(target_param_list) == 1
+        target_param = target_param_list[0]
+        return target_param
+
+    def forward(  # type: ignore[override]
+        self,
+        decoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
         q, _ = self.q_proj_decoder(decoder_hidden_states)
         if encoder_hidden_states is None:
             # Encoder KV already cached.
@@ -1300,25 +1447,21 @@ def forward(self, decoder_hidden_states, encoder_hidden_states):
             k, v = kv_enc.split(self.kv_size, dim=-1)
         return q, k, v
 
-    def weight_loader_weight(self,
-                             param: torch.nn.Parameter,
-                             loaded_weight: torch.Tensor,
-                             loaded_shard_id: Optional[str] = None):
-        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
-        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.weight
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
-
-    def weight_loader_bias(self,
-                           param: torch.nn.Parameter,
-                           loaded_weight: torch.Tensor,
-                           loaded_shard_id: Optional[str] = None):
-        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.bias
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+    def weight_loader(self,
+                      param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[str] = None):
+        layer = (self.q_proj_decoder
+                 if loaded_shard_id == "q" else self.kv_proj_encoder)
+        target_param = self.select_proj_params(layer, param)
+        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
+        layer.weight_loader(target_param, loaded_weight, *shard_id_args)
+
+    def extra_repr(self) -> str:
+        s = f"in_features={self.input_size}"
+        s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
+        s += f", kv_size={self.kv_size}"
+        s += f", bias={self.bias is not None}"
+        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += ", gather_output=False"
+        return s
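Not part of the diff: a framework-free toy restating the dispatch rule the unified weight_loader above implements; the helper name route_shard is illustrative, not vLLM API. "q" shards are loaded through q_proj_decoder without a shard id, while "k"/"v" shards are forwarded to the fused kv_proj_encoder together with their shard id.

# Hypothetical stand-in mirroring QKVCrossParallelLinear.weight_loader's routing.
def route_shard(loaded_shard_id=None):
    layer = "q_proj_decoder" if loaded_shard_id == "q" else "kv_proj_encoder"
    shard_id_args = () if loaded_shard_id == "q" else (loaded_shard_id,)
    return layer, shard_id_args


assert route_shard("q") == ("q_proj_decoder", ())
assert route_shard("k") == ("kv_proj_encoder", ("k",))
assert route_shard("v") == ("kv_proj_encoder", ("v",))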
