@@ -161,7 +161,7 @@ std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
   // map from an index in shape to its index in new_shape
   std::vector<int64_t> idx_map(shape.size(), -1);
   for (int i = 0, n = static_cast<int>(shape.size()); i < n; ++i) {
-    if (shape[id] != 1) {
+    if (shape[i] != 1) {
       idx_map[i] = static_cast<int64_t>(new_shape.size());
       new_shape.emplace_back(shape[i]);
     }
@@ -272,6 +272,139 @@ std::vector<std::shared_ptr<DimTrans>> GetDimTrans(
   return ret_dim_trans;
 }
 
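+// Co-shard-aware counterpart of GetDimTrans: walks the DimTrans tree with a
+// multi-dims mapping (each input dim may be sharded on several mesh dims at
+// once), records every input dim it visits in seen_dims, marks dims that
+// cannot keep their sharding in shardable, and returns the InputDims whose
+// sharding can be carried over to the output dimension being built.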
+std::vector<std::shared_ptr<DimTrans>> GetDimTransCoShard(
+    const std::shared_ptr<DimTrans> dim_trans,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& mesh_shape,
+    const std::vector<std::vector<int64_t>>& input_dims_mapping,
+    const std::set<int64_t>& sharded_input_dims,
+    std::vector<std::vector<bool>>* shardable,
+    std::set<int64_t>* seen_dims) {
+  DimTrans::Type type = dim_trans->type();
+  std::vector<std::shared_ptr<DimTrans>> ret_dim_trans;
+
+  if (type == DimTrans::Type::INPUTDIM) {
+    std::shared_ptr<InputDim> inputdim =
+        std::dynamic_pointer_cast<InputDim>(dim_trans);
+    int64_t dim = inputdim->input_dim();
+    seen_dims->insert(dim);
+
+    if (sharded_input_dims.count(dim) > 0) {
+      ret_dim_trans.push_back(dim_trans);
+    }
+  } else if (type == DimTrans::Type::FLATTEN) {
+    std::shared_ptr<Flatten> flatten =
+        std::dynamic_pointer_cast<Flatten>(dim_trans);
+    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();
+
+    int64_t nmesh = (*shardable)[0].size();  // NOLINT
+    int64_t mesh_shape_prod = 1;
+
+    int last_shard_idx = -1;
+    int first_shard_idx = -1;
+    int64_t first_sharded_shape = -1;
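+    // Scan the leading run of flattened inputs: sharding can only be kept
+    // while consecutive inputs are sharded and the first sharded dim's size
+    // stays divisible by the accumulated product of its mesh dims.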
+    for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
+      std::shared_ptr<DimTrans> input = inputs[i];
+      if (input->type() == DimTrans::Type::INPUTDIM) {
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
+        if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
+          if (first_shard_idx == -1) {
+            first_shard_idx = i;
+            first_sharded_shape = input_shape[inputdim->input_dim()];
+          }
+          for (const auto& dim : input_dims_mapping[inputdim->input_dim()]) {
+            mesh_shape_prod *= mesh_shape[dim];
+          }
+          if (first_sharded_shape % mesh_shape_prod == 0) {
+            ret_dim_trans.push_back(inputdim);
+          } else {
+            break;
+          }
+        } else {
+          break;
+        }
+        last_shard_idx = i;
+      } else {
+        break;
+      }
+    }
+
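+    // The remaining flattened inputs cannot keep their sharding: mark their
+    // input dims unshardable on every mesh dim and recurse so seen_dims and
+    // shardable stay up to date.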
+    for (int i = last_shard_idx + 1, n = static_cast<int>(inputs.size()); i < n;
+         i++) {
+      std::shared_ptr<DimTrans> input = inputs[i];
+      if (input->type() == DimTrans::Type::INPUTDIM) {
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(input);
+        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
+      }
+
+      GetDimTransCoShard(input,
+                         input_shape,
+                         mesh_shape,
+                         input_dims_mapping,
+                         sharded_input_dims,
+                         shardable,
+                         seen_dims);
+    }
+  } else if (type == DimTrans::Type::SPLIT) {
+    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
+    std::vector<std::shared_ptr<DimTrans>> dims =
+        GetDimTransCoShard(split->input(),
+                           input_shape,
+                           mesh_shape,
+                           input_dims_mapping,
+                           sharded_input_dims,
+                           shardable,
+                           seen_dims);
+    int64_t ret_size = split->local_split_shape_value();
+
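+    // Only the first output of a split (split_id == 0) may inherit sharding,
+    // and only if the local split shape is divisible by the product of the
+    // mesh dims that the source dims are sharded on.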
+    if (split->split_id() == 0) {
+      int64_t mesh_shape_prod = 1;
+      int64_t first_shard_idx = -1;
+      int64_t first_sharded_shape = -1;
+      for (const auto& dim : dims) {
+        PADDLE_ENFORCE_EQ(dim->type(),
+                          DimTrans::Type::INPUTDIM,
+                          common::errors::InvalidArgument(
+                              "The returned dim_trans must be INPUTDIM."));
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(dim);
+        int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
+        int64_t input_axis = inputdim->input_dim();
+
+        // Check whether the sharded dim can be sharded on
+        // each mesh dimension. The dimension should be
+        // divisible by the mesh size that it is sharded on.
+        for (int64_t imesh = 0; imesh < nmesh; imesh++) {
+          (*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
+        }
+
+        if (first_shard_idx == -1) {
+          first_shard_idx = input_axis;
+          first_sharded_shape = input_shape[input_axis];
+        }
+
+        if (sharded_input_dims.count(input_axis) > 0) {
+          for (const auto& dim : input_dims_mapping[input_axis]) {
+            mesh_shape_prod *= mesh_shape[dim];
+          }
+          if ((ret_size % mesh_shape_prod == 0) &&
+              (first_sharded_shape % mesh_shape_prod == 0)) {
+            ret_dim_trans.push_back(dim);
+          } else {
+            break;
+          }
+        } else {
+          break;
+        }
+      }
+    }
+  } else if (type == DimTrans::Type::SINGLETON) {
+  }
+  return ret_dim_trans;
+}
+
 void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
                      std::set<int64_t>* seen_dims) {
   if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
@@ -311,6 +444,27 @@ InferFromDimTrans(const DistMetaTensor& input_spec,
   return InferFromDimTrans(input_spec, input_shape, dim_trans);
 }
 
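+// Overload that derives input_shape from the DistMetaTensor itself, dropping
+// the leading 0 that reshape's xshape carries in dynamic mode.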
+std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
+InferFromDimTransCoShard(
+    const DistMetaTensor& input_spec,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
+  auto input_shape = phi::vectorize(input_spec.dims());
+  // deal with reshape's xshape in dynamic mode
+  if (input_shape[0] == 0 &&
+      input_shape.size() !=
+          input_spec.dist_attr().multi_dims_mapping().size()) {
+    input_shape.erase(input_shape.begin());
+  }
+  PADDLE_ENFORCE_EQ(input_shape.size(),
+                    input_spec.dist_attr().multi_dims_mapping().size(),
+                    common::errors::InvalidArgument(
+                        "The Tensor X's rank [%d] and X's "
+                        "dims_mapping size [%d] are not matched.",
+                        input_shape.size(),
+                        input_spec.dist_attr().multi_dims_mapping().size()));
+  return InferFromDimTransCoShard(input_spec, input_shape, dim_trans);
+}
+
 std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
 InferFromDimTrans(const DistMetaTensor& input,
                   const std::vector<int64_t>& input_shape,
@@ -400,4 +554,105 @@ InferFromDimTrans(const DistMetaTensor& input,
   return {new_input_dims_mapping, out_dims_mapping};
 }
 
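+// Co-shard variant of InferFromDimTrans: dims mappings are vectors of mesh
+// dims (multi_dims_mapping), so one tensor dim may be sharded on several mesh
+// dims at once. Returns {new input dims mapping, output dims mapping}.
+// Illustrative case (assumed, not taken from this diff): flattening [6, 4]
+// into [24] on a mesh of shape [2, 3] with input mapping [[0, 1], []] lets
+// output dim 0 keep the co-sharding [0, 1], since 6 is divisible by 2 * 3.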
+std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
+InferFromDimTransCoShard(
+    const DistMetaTensor& input,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
+  const std::vector<std::vector<int64_t>>& input_dims_mapping =
+      input.dist_attr().multi_dims_mapping();
+  const ProcessMesh& mesh = input.dist_attr().process_mesh();
+  const std::vector<int64_t>& mesh_shape = mesh.shape();
+
+  std::set<int64_t> sharded_input_dims;
+  for (int64_t i = 0, n = static_cast<int64_t>(input_dims_mapping.size());
+       i < n;
+       ++i) {
+    if (std::any_of(input_dims_mapping[i].begin(),
+                    input_dims_mapping[i].end(),
+                    [](int64_t dim) { return dim > -1; })) {
+      sharded_input_dims.insert(i);
+    }
+  }
+  int64_t ndim = static_cast<int64_t>(input_shape.size());
+  int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
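+  // shardable[i][j] records whether input dim i may keep its sharding on
+  // mesh dim j; start optimistic and prune below.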
+  std::vector<std::vector<bool>> shardable(ndim,
+                                           std::vector<bool>(nmesh, true));
+
+  std::set<int64_t> seen_input_dims;
+  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
+    GetUsedInputDim(trans, &seen_input_dims);
+  }
+
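+  // An input dim never referenced by any dim_trans cannot stay sharded.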
+  for (int64_t idim = 0; idim < ndim; idim++) {
+    bool seen = seen_input_dims.count(idim);
+    if (!seen) {
+      shardable[idim].assign(nmesh, seen);
+    }
+  }
+
+  // get the map between sharded input dims and output dims:
+  // dim_map_src2tgt maps a source (input) dim to its target (output) dim,
+  // dim_map_dst2src maps an output dim back to the input dims merged into it.
+  std::vector<int64_t> dim_map_src2tgt(ndim, -1);
+  std::unordered_map<int, std::vector<int>> dim_map_dst2src;
+  for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
+    std::vector<std::shared_ptr<DimTrans>> dims =
+        GetDimTransCoShard(dim_trans[i],
+                           input_shape,
+                           mesh_shape,
+                           input_dims_mapping,
+                           sharded_input_dims,
+                           &shardable,
+                           &seen_input_dims);
+    for (auto dim : dims) {
+      if (dim->type() == DimTrans::Type::INPUTDIM) {
+        std::shared_ptr<InputDim> inputdim =
+            std::dynamic_pointer_cast<InputDim>(dim);
+        dim_map_src2tgt[inputdim->input_dim()] = i;
+        dim_map_dst2src[i].push_back(inputdim->input_dim());
+      }
+    }
+  }
+
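+  // Initialize the new input mappings from the original ones; entries that
+  // turn out to be unshardable are cleared further below.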
+  std::vector<std::vector<int64_t>> out_dims_mapping(dim_trans.size());
+  std::vector<std::vector<int64_t>> new_input_dims_mapping(
+      input_dims_mapping.size());
+  for (size_t i = 0; i < input_dims_mapping.size(); i++) {
+    if (std::any_of(input_dims_mapping[i].begin(),
+                    input_dims_mapping[i].end(),
+                    [](int64_t dim) { return dim >= 0; })) {
+      new_input_dims_mapping[i] = input_dims_mapping[i];
+    }
+  }
+
+  // set the output dims mapping from the corresponding input dims.
+  // if an input dim ends up sharded on an unshardable mesh dim after
+  // splitting, make it replicated instead.
+  for (int64_t i = 0; i < ndim; i++) {
+    std::vector<int64_t> mesh_dims = input_dims_mapping[i];
+    for (int64_t mesh_dim : mesh_dims) {
+      if (mesh_dim > -1 && shardable[i][mesh_dim] && dim_map_src2tgt[i] > -1) {
+        int dst_dim = dim_map_src2tgt[i];
+        out_dims_mapping[dst_dim].push_back(mesh_dim);
+
+        auto src_dim = dim_map_dst2src[dst_dim];
+        auto min_dim = std::min_element(src_dim.begin(), src_dim.end());
+        if (*min_dim == i) {
+          new_input_dims_mapping[*min_dim].push_back(mesh_dim);
+          continue;
+        }
+        new_input_dims_mapping[*min_dim].insert(
+            new_input_dims_mapping[*min_dim].end(),
+            new_input_dims_mapping[i].begin(),
+            new_input_dims_mapping[i].end());
+        new_input_dims_mapping[i] = {};
+      } else {
+        new_input_dims_mapping[i] = {};
+      }
+    }
+  }
+
+  return {new_input_dims_mapping, out_dims_mapping};
+}
+
 }  // namespace phi::distributed