
Conversation

@liufengwei0103
Contributor

@liufengwei0103 liufengwei0103 commented Jun 10, 2025

PR Category

Auto Parallel

PR Types

New features

Description

Main features:
1. Enhance the Shard placement so that the same tensor dim can be sharded by multiple mesh dims: the new co_shard_order argument specifies the split order, which makes it possible to merge multiple sharded tensor dims in reshape.
2. Enhance the reshard API to express rearranging data before sharding a tensor, which makes it possible to reshard a fused QKV weight in a distributed environment.

Main changes:
1. Upgrade dims_mapping from a vector to a vector of vectors (see the sketch after this list).
2. Refactor the nd_mesh reshard transform.
3. Add co_shard_order and split_factor to the Shard placement.
4. Add a dims_mapping proxy to keep old spmd rules backward compatible during the transition from the flat dims_mapping to the new vector-of-vectors dims_mapping.
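
The following is a minimal illustration in plain Python (not Paddle internals) of what the dims_mapping upgrade roughly looks like; the helper to_legacy_dims_mapping is a hypothetical name used only to sketch the idea behind the backward-compatibility proxy.

# Illustration only: each entry of dims_mapping describes how one tensor dim is sharded.
# The legacy form stores a single mesh dim per tensor dim (-1 means replicated);
# the new form stores a list of mesh dims, so one tensor dim can be co-sharded by
# several mesh dims (their order corresponding to co_shard_order).
legacy_dims_mapping = [0, -1]      # tensor dim 0 sharded by mesh dim 0, dim 1 replicated
new_dims_mapping = [[0, 1], []]    # tensor dim 0 co-sharded by mesh dims 0 and 1, dim 1 replicated

def to_legacy_dims_mapping(nested):
    # Hypothetical helper sketching the proxy idea: collapse the nested form back
    # to the flat form when every tensor dim is sharded by at most one mesh dim.
    assert all(len(dims) <= 1 for dims in nested)
    return [dims[0] if dims else -1 for dims in nested]

print(to_legacy_dims_mapping([[0], []]))  # [0, -1]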

Usage:
Get a co-sharded tensor

import paddle
import paddle.distributed as dist

a = paddle.to_tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
placements = [
    dist.Shard(0, co_shard_order=0),
    dist.Shard(0, co_shard_order=1),
]
b = dist.shard_tensor(a, mesh, placements)
print(b.placements)
# [Shard(0, shard_order=0), Shard(0, shard_order=1)]
print(b._local_value())
# rank0 [[1, 2]], rank1 [[3, 4]], rank2 [[5, 6]], rank3 [[7, 8]]
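
To make the split order concrete, here is a plain-numpy sketch (not the Paddle API) that reproduces the local shards printed above: the mesh dim with co_shard_order=0 does the outer split of tensor dim 0 and the mesh dim with co_shard_order=1 does the inner split.

import numpy as np

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
mesh = np.array([[0, 1], [2, 3]])     # 2 x 2 process mesh, dims ('x', 'y')
nx, ny = mesh.shape

for rank in mesh.flatten():
    cx, cy = (int(v) for v in np.argwhere(mesh == rank)[0])
    piece = cx * ny + cy                 # outer split by 'x' (order 0), inner split by 'y' (order 1)
    rows = a.shape[0] // (nx * ny)       # rows per shard along tensor dim 0
    local = a[piece * rows:(piece + 1) * rows]
    print(int(rank), local.tolist())
# rank 0: [[1, 2]], rank 1: [[3, 4]], rank 2: [[5, 6]], rank 3: [[7, 8]]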

Co-sharding in reshape

import paddle
import paddle.distributed as dist

a = paddle.to_tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype='float32')
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
placements = [dist.Shard(0), dist.Shard(1)]
input = dist.shard_tensor(a, mesh, placements)
out = paddle.reshape(input, [-1])
print(out.placements)
# [Shard(0, shard_order=0), Shard(0, shard_order=1)]

Rearrange data before sharding

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
a = paddle.to_tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
placements = [dist.Shard(0, split_factor=2), dist.Replicate()]
b = dist.shard_tensor(a, mesh, placements)
print(b.placements)
# [Shard(0, split_factor=2), Replicate()]
print(b._local_value())
# rank0 and rank1: [[1, 2], [5, 6]]; rank2 and rank3: [[3, 4], [7, 8]]
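
The behaviour of split_factor, inferred from the printed local values above, can be sketched in plain numpy (not the Paddle API): tensor dim 0 is first cut into split_factor * mesh_dim_size chunks, and the chunks are then dealt out round-robin, so each shard ends up holding interleaved rows.

import numpy as np

a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
mesh_dim_size, split_factor = 2, 2
chunks = np.split(a, mesh_dim_size * split_factor, axis=0)   # 4 chunks of one row each
for shard in range(mesh_dim_size):
    local = np.concatenate(chunks[shard::mesh_dim_size], axis=0)
    print(shard, local.tolist())
# shard 0 (rank0, rank1): [[1, 2], [5, 6]]
# shard 1 (rank2, rank3): [[3, 4], [7, 8]]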

More use cases can be seen in the test cases.

Pcard-67164

@paddle-bot

paddle-bot bot commented Jun 10, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@liufengwei0103 liufengwei0103 changed the title from "support to shard on the same tensor dim by many mesh dim, only dynami…" to "support to shard on the same tensor dim by many mesh dim, only dynamic graph" Jun 10, 2025
@liufengwei0103 liufengwei0103 marked this pull request as ready for review June 11, 2025 06:17
@liufengwei0103 liufengwei0103 marked this pull request as draft June 18, 2025 23:27
@liufengwei0103 liufengwei0103 marked this pull request as ready for review June 18, 2025 23:27
@jeff41404
Contributor

The results of the three example code snippets in the Description above also need to be explained through print output or comments, to make them easier for others to understand.

Contributor

@From00 From00 left a comment


User documentation needs to be added as a follow-up, covering how to use this feature and how to add spmd rules.

zhiqiu
zhiqiu previously approved these changes Jun 19, 2025
Contributor

@zhiqiu zhiqiu left a comment


LGTM

@liufengwei0103
Contributor Author

The results of the three example code snippets in the Description above also need to be explained through print output or comments, to make them easier for others to understand.

done

Contributor

@From00 From00 left a comment


LGTM

Contributor

@jeff41404 jeff41404 left a comment


LGTM

Contributor

@luotao1 luotao1 left a comment


LGTM

Member

@SigureMo SigureMo left a comment


LGTMeow 🐾 for pybind API without type annotations

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment


LGTM

@From00 From00 merged commit 2327fff into PaddlePaddle:develop Jun 21, 2025
49 of 52 checks passed
