@@ -25,7 +25,7 @@ def pplx_garden_hidden_dim_scale(
     quant_dtype: torch.dtype | str | None,
     per_act_token_quant: bool,
     block_shape: list[int] | None,
-) -> int:
+) -> int | None:
     # For blocked per token: set to
     # ceil_div(hidden_dim, block_size) * sizeof(float32)
     # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
@@ -37,16 +37,16 @@ def pplx_garden_hidden_dim_scale(
         if per_act_token_quant:
             # per-token (M x 1)
             assert block_shape is None
-            hidden_dim_scale = 1
+            hidden_dim_scale = 16
         elif block_shape is not None:
             # per-group (M x K_tiles)
             block_size = block_shape[1]
             hidden_dim_scale = cdiv(hidden_dim, block_size)
         else:
             # per-tensor (1 x 1)
-            hidden_dim_scale = 1
+            hidden_dim_scale = 16
     else:
-        hidden_dim_scale = 0
+        hidden_dim_scale = None  # 1?

     return hidden_dim_scale

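For readers skimming the hunk above, here is a minimal standalone sketch of the scale-width logic after this change, with the three quantization modes spelled out. The `cdiv` helper mirrors vLLM's ceiling-division utility; reading the new value 16 as a padded per-token/per-tensor scale width is an assumption based on the alignment changes elsewhere in this commit, not something the diff states.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching the cdiv helper used in the diff.
    return -(a // -b)


def hidden_dim_scale(
    hidden_dim: int,
    quantized: bool,
    per_act_token_quant: bool,
    block_shape: list[int] | None,
) -> int | None:
    if not quantized:
        return None  # unquantized path: no scale tensor at all
    if per_act_token_quant:
        assert block_shape is None
        return 16  # per-token (M x 1), widened to 16 for alignment (assumed)
    if block_shape is not None:
        return cdiv(hidden_dim, block_shape[1])  # per-group (M x K_tiles)
    return 16  # per-tensor (1 x 1), widened to 16 for alignment (assumed)


# e.g. a hypothetical 7168-wide hidden dim with 128-wide quantization groups:
assert hidden_dim_scale(7168, True, False, [128, 128]) == 56
assert hidden_dim_scale(7168, False, False, None) is None
```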
@@ -190,7 +190,7 @@ def prepare_async(
         expert_x_scale_shape = (
             self.num_local_experts,
             expert_x.size(1),
-            round_up(final_dim, 4),  # round up for alignment
+            round_up(final_dim, 16),  # round up for alignment
         )

         expert_x_scale = torch.empty(
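To make the alignment change concrete, a small example with a hypothetical group count showing how the last axis of `expert_x_scale` grows when the padding multiple moves from 4 to 16:

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the next multiple, as vLLM's round_up helper does.
    return ((x + multiple - 1) // multiple) * multiple


final_dim = 56  # hypothetical: cdiv(7168, 128) scale groups per token
assert round_up(final_dim, 4) == 56   # old padding multiple
assert round_up(final_dim, 16) == 64  # new padding multiple
```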
@@ -203,7 +203,11 @@ def prepare_async(
         # There's not much point setting this unless it is != indices.size(0)
         bound_m: torch.Tensor | None = None

-        logger.debug("PPLX_GARDEN dispatch send %s", expert_x.shape)
+        logger.debug(
+            "PPLX_GARDEN dispatch send %s, %s",
+            expert_x.shape,
+            expert_x_scale.shape if expert_x_scale is not None else None,
+        )

         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
@@ -269,7 +273,8 @@ def _receiver(
             "PPLX_GARDEN receive X_SCALE %s",
             expert_x_scale.shape if expert_x_scale is not None else None,
         )
-        logger.debug("PPLX_GARDEN receive META %s", expert_tokens_meta)
+        logger.debug("PPLX_GARDEN receive num_tokens %s", expert_num_tokens.shape)
+        # logger.debug("PPLX_GARDEN receive META %s", expert_tokens_meta)

         return expert_x, expert_x_scale, expert_tokens_meta, None, None

@@ -332,11 +337,13 @@ def finalize_async(

         logger.debug("PPLX_GARDEN combine send")

+        hidden_dim = output.size(1)
+
         self.a2a.combine(
             out_tokens=output,
             indices=topk_ids_u32,
             weights=topk_weights,
-            expert_y=fused_expert_output,
+            expert_y=fused_expert_output.view(-1, hidden_dim),
             bound_m=bound_m,
             do_send=True,
             do_recv=False,
@@ -349,7 +356,7 @@ def finalize_async(
             out_tokens=output,
             indices=topk_ids_u32,
             weights=topk_weights,
-            expert_y=fused_expert_output,
+            expert_y=fused_expert_output.view(-1, hidden_dim),
             bound_m=bound_m,
             do_send=False,
             do_recv=True,
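The `.view(-1, hidden_dim)` added in both combine calls flattens the expert output before handing it to the kernel. A hedged sketch, assuming `fused_expert_output` is a contiguous 3D `(num_local_experts, max_tokens, hidden_dim)` buffer and that `combine` expects a 2D `(rows, hidden_dim)` tensor:

```python
import torch

num_local_experts, max_tokens, hidden_dim = 4, 8, 16  # hypothetical sizes
fused_expert_output = torch.randn(num_local_experts, max_tokens, hidden_dim)

# Flatten the leading expert/token dims into one row axis; view() requires
# contiguous memory and shares storage, so no data is copied.
expert_y = fused_expert_output.view(-1, hidden_dim)
assert expert_y.shape == (num_local_experts * max_tokens, hidden_dim)
assert expert_y.data_ptr() == fused_expert_output.data_ptr()
```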