@@ -194,7 +194,8 @@ def workspace_shapes(
         N: int,
         K: int,
         topk: int,
-        num_experts: int,
+        global_num_experts: int,
+        local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -372,8 +373,9 @@ def forward(
         a1 = hidden_states
         output = a1 if inplace else torch.zeros_like(a1)
 
+        local_num_experts = w1.size(0)
         if global_num_experts == -1:
-            global_num_experts = w1.size(0)
+            global_num_experts = local_num_experts
 
         (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
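Hoisting local_num_experts = w1.size(0) out of the -1 fallback makes the local/global distinction explicit: under expert parallelism, w1 stacks only this rank's expert weights, so its leading dimension is the per-rank count. A minimal sketch of the relationship, assuming hypothetical example values for ep_size, N, and K:

import torch

# Illustration only: how the local and global expert counts relate under
# expert parallelism. ep_size, N, and K are hypothetical example values.
ep_size = 4
N, K = 2048, 1024
global_num_experts = 64
local_num_experts = global_num_experts // ep_size  # 16 experts on this rank
# w1 stacks one weight matrix per *local* expert, hence w1.size(0) above.
w1 = torch.empty(local_num_experts, 2 * N, K)
assert w1.size(0) == local_num_experts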
@@ -408,16 +410,19 @@ def forward(
         if num_chunks == 1:
             (workspace13_shape, workspace2_shape, fused_out_shape,
              workspace_dtype) = self.fused_experts.workspace_shapes(
-                a1, a1q, M, N, K, top_k, global_num_experts)
+                a1, a1q, M, N, K, top_k, global_num_experts,
+                local_num_experts)
         else:
             # Use the full M to get the final output shape.
             _, _, fused_out_shape, _ = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, M, N, K, top_k, global_num_experts))
+                    a1, a1q, M, N, K, top_k, global_num_experts,
+                    local_num_experts))
             # Use the CHUNK_SIZE to get the workspace shapes.
             workspace13_shape, workspace2_shape, _, workspace_dtype = (
                 self.fused_experts.workspace_shapes(
-                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
+                    a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
+                    local_num_experts))
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
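The asymmetry in the chunked branch above is deliberate: the fused output must span all M tokens, while the scratch workspaces only ever hold one CHUNK_SIZE slice at a time. A standalone sketch of that allocation pattern (the loop and all names here are assumptions for illustration, not this file's code):

import math
import torch

M, CHUNK_SIZE, K = 10_000, 1024, 4096
num_chunks = math.ceil(M / CHUNK_SIZE)

fused_out = torch.empty(M, K)           # final output covers the full M
workspace = torch.empty(CHUNK_SIZE, K)  # scratch reused across chunks

for i in range(num_chunks):
    begin, end = i * CHUNK_SIZE, min((i + 1) * CHUNK_SIZE, M)
    n = end - begin
    # ... a kernel would write into workspace[:n] here, with the chunk's
    # result landing in fused_out[begin:end] ...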