@@ -78,7 +78,12 @@ def _all_to_all_in_static_mode(
     if isinstance(in_tensor_or_tensor_list, list):
         if len(in_tensor_or_tensor_list) == 0:
             raise RuntimeError("The input tensor_list should not be empty.")
-        in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
+        # 0D tensors use stack/unstack while others use concat/split
+        if len(in_tensor_or_tensor_list[0].shape) == 0:
+            in_tensor = paddle.stack(in_tensor_or_tensor_list, axis=0)
+        else:
+            in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
+
     out_tensor = out_tensor_or_tensor_list
     if isinstance(out_tensor_or_tensor_list, list):
         if len(out_tensor_or_tensor_list) != 0:
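
Note: a minimal standalone sketch of the packing rule on the send side (the tensors here are illustrative, not part of the patch). concat joins tensors along an existing axis, which 0D (scalar) tensors lack, so stack is used to add a leading axis instead:

    import paddle

    xs = [paddle.to_tensor(1.0), paddle.to_tensor(2.0)]      # two 0D tensors, shape []
    packed = paddle.stack(xs, axis=0)                        # shape [2]
    ys = [paddle.to_tensor([1.0]), paddle.to_tensor([2.0])]  # two 1D tensors, shape [1]
    packed_nd = paddle.concat(ys, axis=0)                    # shape [2]
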
@@ -110,7 +115,13 @@ def _all_to_all_in_static_mode(
     if isinstance(out_tensor_or_tensor_list, list):
         if not sync_op:
             dist.wait(out_tensor, use_calc_stream=False)
-        out_tensor_or_tensor_list.extend(paddle.split(out_tensor, nranks, 0))
+        # 0D tensors use stack/unstack while others use concat/split
+        if len(in_tensor_or_tensor_list[0].shape) == 0:
+            out_tensor_or_tensor_list.extend(paddle.unstack(out_tensor, 0))
+        else:
+            out_tensor_or_tensor_list.extend(
+                paddle.split(out_tensor, nranks, 0)
+            )
 
     return None
 
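
Note: the receive side mirrors this. A minimal sketch, assuming nranks = 2 and hypothetical packed outputs: unstack removes the leading axis to recover 0D tensors, while split preserves each chunk's rank:

    import paddle

    nranks = 2
    out = paddle.to_tensor([3.0, 4.0])         # packed 0D results, shape [nranks]
    scalars = paddle.unstack(out, 0)           # list of two 0D tensors, shape []
    out_nd = paddle.to_tensor([[1.0], [2.0]])  # packed N-D results, shape [2, 1]
    chunks = paddle.split(out_nd, nranks, 0)   # list of two tensors, shape [1, 1]
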