@@ -79,6 +79,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        # not needed by fp8
+        extra_weight_attrs.pop("intermediate_full")
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
@@ -269,12 +271,12 @@ def __init__(
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
-                       intermediate_full: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+                       params_dtype: torch.dtype, **extra_weight_attrs):
 
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
+        intermediate_full = extra_weight_attrs.pop("intermediate_full")
         extra_weight_attrs.update({
             "is_transposed": True,
             "quant_method": self.strategy
@@ -297,15 +299,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        self.is_k_full = (intermediate_full == intermediate_size)
-        scales_size = (intermediate_full if self.actorder
-                       and self.group_size != -1 else intermediate_size)
+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (intermediate_full
+                          if load_full_w2 else intermediate_size)
+        # @eliza TODO: is this condition actually needed/is it doing anything?
+        self.is_k_full = (not self.actorder) or (
+            self.actorder and intermediate_size == intermediate_full)
 
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = scales_size // self.group_size
+            num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
         w13_scale = torch.nn.Parameter(torch.ones(num_experts,
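A quick worked example of the sizing logic introduced above (the numbers are made up for illustration): with act order and grouped quantization, the w2 scales are kept at the full intermediate size rather than the TP-sharded one, and is_k_full is only False when act order is used with a partitioned K dimension.

    # Illustrative only; the sizes are assumptions, not taken from the PR.
    intermediate_full = 14336   # checkpoint's full intermediate dim
    intermediate_size = 7168    # per-rank shard (TP = 2)
    group_size = 128

    for actorder in (None, "group"):
        load_full_w2 = bool(actorder) and group_size != -1
        w2_scales_size = intermediate_full if load_full_w2 else intermediate_size
        is_k_full = (not actorder) or (intermediate_size == intermediate_full)
        print(actorder, w2_scales_size // group_size, is_k_full)
    # None    ->  56 groups (sharded scales), is_k_full=True
    # "group" -> 112 groups (full scales),    is_k_full=False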
@@ -323,6 +330,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                       requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_scale)
         set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
 
         w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
@@ -432,6 +440,8 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device
 
+        # when running models with grouped act order,
+        # resort g_idx values provided
         if self.actorder == "group":
             w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
             w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
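For context on the re-sorting mentioned in the new comment: under grouped act order the K dimension is reordered so that channels belonging to the same quantization group are contiguous, and the permutation is kept as sort indices so it can be applied to the activations. The rest of this branch is outside the hunk, so the snippet below is an assumed, standalone illustration of that idea rather than the PR's code:

    import torch

    # Assumed illustration: derive sort indices with argsort and use them to put
    # g_idx (group id per input channel) into non-decreasing order.
    g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int32)  # one expert's g_idx
    sort_indices = torch.argsort(g_idx).to(torch.int32)
    sorted_g_idx = g_idx[sort_indices]
    print(sort_indices.tolist(), sorted_g_idx.tolist())
    # e.g. [1, 3, 2, 5, 0, 4] and [0, 0, 1, 1, 2, 2]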
@@ -552,5 +562,4 @@ def apply(
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
-            is_k_full=self.is_k_full,
-        )
+            is_k_full=self.is_k_full)