
Commit 4c71143

fix condition for is_k_full; clean-up

1 parent ba92f95

File tree: 2 files changed, +46 −26 lines

vllm/model_executor/layers/fused_moe/layer.py

29 additions & 18 deletions
@@ -277,15 +277,26 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(
-            self, shard_dim: int, expert_data: torch.Tensor, shard_id: str,
-            loaded_weight: torch.Tensor, tp_rank: int, load_full_w2: bool):
-        # Load grouped weight scales for group quantization
-        # or model weights
-        # In act_order scenario, we need to load full w2 scales
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
+                                                 expert_data: torch.Tensor,
+                                                 shard_id: str,
+                                                 loaded_weight: torch.Tensor,
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights
+        :param shard_dim: dimension to shard
+        :param expert_data: parameter for a particular expert
+        :param shard_id: either w1, w2, or w3
+        :param loaded_weight: checkpoint weight to load into the param
+        :param tp_rank: tensor parallel rank
+        :param load_full_w2: whether or not the w2 loaded should be sharded.
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # In the case where we have actorder/g_idx, we do not partition the
+            # w2 scales, as indicated by `load_full` argument, for all tp cases
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank,
@@ -329,9 +340,12 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
         expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)
 
-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
-                 load_full: bool):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
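For context, a minimal standalone sketch of what the load_full toggle does in this helper; the function name, shapes, and variables here are illustrative, not the exact vLLM code. When load_full is False the checkpoint tensor is narrowed to this rank's slice along the sharded input dim before copying; when True the destination was allocated at the full, unsharded size (e.g. act-order w2 scales) and the whole tensor is copied.

import torch

def shard_w2_sketch(expert_data: torch.Tensor,
                    loaded_weight: torch.Tensor,
                    shard_dim: int,
                    tp_rank: int,
                    load_full: bool = False) -> None:
    # Row-parallel style: each TP rank owns one contiguous slice of the
    # sharded (input) dim; the slice width comes from the destination param.
    shard_size = expert_data.shape[shard_dim]
    if not load_full:
        loaded_weight = loaded_weight.narrow(shard_dim,
                                             shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)

# Example: full dim of 8 with TP=2 -> rank 1 copies its 4-wide slice.
full = torch.arange(8.0).view(1, 8)
dst = torch.empty(1, 4)
shard_w2_sketch(dst, full, shard_dim=1, tp_rank=1)
print(dst)  # tensor([[4., 5., 6., 7.]])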
@@ -355,12 +369,10 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                     shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
 
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank,
-                          load_full=False)
+                          tp_rank=tp_rank)
         else:
             assert shard_id in ("w1", "w3")
             expert_data.copy_(loaded_weight)
@@ -450,7 +462,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
                     tp_rank=tp_rank,
-                    load_full_w2=True)
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,
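The hard-coded load_full_w2=True above is replaced by reading a flag off the parameter itself, so only scales that the quant method explicitly tagged (see the set_weight_attrs call added in compressed_tensors_moe.py below) are loaded unsharded; everything else falls back to False. A small illustration of that handshake, with a stand-in tensor rather than the real w2 scale:

import torch

w2_scale = torch.nn.Parameter(torch.ones(2, 3, 4), requires_grad=False)
# The quant method tags the parameter when act-order group scales are kept
# at their full (unsharded) size.
setattr(w2_scale, "load_full_w2", True)

# The generic weight_loader later reads the flag back, defaulting to False
# for any parameter that was never tagged.
print(getattr(w2_scale, "load_full_w2", False))                            # True
print(getattr(torch.nn.Parameter(torch.ones(1)), "load_full_w2", False))   # False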
@@ -476,8 +488,7 @@ def weight_loader(self, param: torch.nn.Parameter,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
-                tp_rank=tp_rank,
-                load_full_w2=False)
+                tp_rank=tp_rank)
             return
 
     @staticmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

17 additions & 8 deletions
@@ -79,6 +79,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        # not needed by fp8
+        extra_weight_attrs.pop("intermediate_full")
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
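Reading this together with the next hunk: intermediate_full apparently no longer arrives as a positional argument of create_weights but rides in **extra_weight_attrs, and each quant method pops it, the fp8 path simply discarding it while the wNA16 path consumes it. A rough sketch of that expectation (the caller's exact kwargs are an assumption, not shown in this diff):

# Sketch only: the full (unsharded) intermediate dim travels alongside the
# other weight attrs and must be popped before the attrs are forwarded to
# set_weight_attrs on the created parameters.
extra_weight_attrs = {
    "weight_loader": lambda *args, **kwargs: None,  # placeholder loader
    "intermediate_full": 4096,                      # full intermediate dim
}
intermediate_full = extra_weight_attrs.pop("intermediate_full")
assert "intermediate_full" not in extra_weight_attrs  # safe to forward now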
@@ -269,12 +271,12 @@ def __init__(
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
-                       intermediate_full: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+                       params_dtype: torch.dtype, **extra_weight_attrs):
 
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
+        intermediate_full = extra_weight_attrs.pop("intermediate_full")
         extra_weight_attrs.update({
             "is_transposed": True,
             "quant_method": self.strategy
@@ -297,15 +299,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        self.is_k_full = (intermediate_full == intermediate_size)
-        scales_size = (intermediate_full if self.actorder
-                       and self.group_size != -1 else intermediate_size)
+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (intermediate_full
+                          if load_full_w2 else intermediate_size)
+        # @eliza TODO: is this condition actually needed/is it doing anything?
+        self.is_k_full = (not self.actorder) or (
+            self.actorder and intermediate_size == intermediate_full)
 
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
         w13_scale = torch.nn.Parameter(torch.ones(num_experts,
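This is the condition fix named in the commit message, and it is easiest to see with concrete numbers (hypothetical sizes, TP=2, bool actorder for brevity; in the file actorder is a strategy value). The old expression reported k as not full whenever the layer was TP-sharded, even without act_order; the new one only reports a partial k when act_order is in use and the dim is sharded. A minimal sketch:

def is_k_full_old(intermediate_full: int, intermediate_size: int) -> bool:
    return intermediate_full == intermediate_size

def is_k_full_new(actorder: bool, intermediate_full: int,
                  intermediate_size: int) -> bool:
    # k is only "partial" when act_order is enabled and the dim is TP-sharded.
    return (not actorder) or (intermediate_size == intermediate_full)

# Full dim 4096, per-rank shard 2048 (TP=2):
print(is_k_full_old(4096, 2048))           # False, even with no act_order
print(is_k_full_new(False, 4096, 2048))    # True  (the fix)
print(is_k_full_new(True, 4096, 2048))     # False (act_order + sharded)

Note the new expression reduces logically to (not actorder) or (intermediate_size == intermediate_full), which is how the sketch writes it.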
@@ -323,6 +330,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                       requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_scale)
         set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
 
         w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
@@ -432,6 +440,8 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device
 
+        # when running models with grouped act order,
+        # resort g_idx values provided
         if self.actorder == "group":
             w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
             w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
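The added comment refers to the grouped act-order path, where each expert's g_idx (the per-column group index after activation reordering) is re-sorted so the Marlin kernel can read contiguous groups. A rough per-expert sketch of what that re-sorting typically looks like, not the literal loop from the file:

import torch

num_experts, size_k = 4, 16
w13_g_idx = torch.randint(0, 4, (num_experts, size_k))

w13_g_idx_sort_indices = torch.empty_like(w13_g_idx)
w13_sorted_g_idx = torch.empty_like(w13_g_idx)
for e in range(num_experts):
    # argsort gives the column permutation that puts group indices in order;
    # both the permutation and the sorted g_idx are kept for the kernel.
    w13_g_idx_sort_indices[e] = torch.argsort(w13_g_idx[e])
    w13_sorted_g_idx[e] = w13_g_idx[e][w13_g_idx_sort_indices[e]]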
@@ -552,5 +562,4 @@ def apply(
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
-            is_k_full=self.is_k_full,
-        )
+            is_k_full=self.is_k_full)
