Skip to content

【开源任务】算子切分推导规则开发,支持更多模型使用自动并行,简化更多用户的分布式开发成本 #72415

@Glencsa

Description

@Glencsa

一、任务背景

1.1 自动并行原理,自动并行与算子切分推导规则的关系

飞桨3.0发布了动静统一的自动并行功能,目的是简化分布式的开发,仅需要用户在神经网络中某些关键位置配置或标注分布式状态,深度学习框架就自动推断出网络中剩余其它部分的分布式状态并执行。自动推断在保证计算结果正确性的前提下,需要实现最优的通信和计算。
具体在框架实现中,当使用自动并行的方式执行1个神经网络时,网络中的输入数据(Tensor)和训练权重都会按照用户的配置和标注,初始化出分布式状态。分布式状态主要包括2部分:设备组织(通过paddle.distributed.ProcessMesh构造) 和 设备对数据(Tensor)的切分状态(通过paddle.distributed.Placement构造)。计算时,会按照神经网络定义执行某些算子的计算,即对于神经网络中的每个算子,只要定义1个规则(下面称 切分推导规则),就可以自动推断出此算子计算过程中需要的通信,进而逐个算子自动推断出整个网络所需要的通信,从而实现最优的通信和计算。
切分推导规则和算子计算逻辑强相关,其实现需要根据算子的计算逻辑、设备对输入数据(Tensor)的切分状态 来推断确定 最优的输入/输出数据(Tensor)的切分状态。如果某个输入/输出的切分状态 和 推断出的最优切分状态 不一样,则框架就能自动推断出在相应位置上所需要的通信。
Image
因此,为了让更多模型能够使用自动并行,简化更多用户的分布式开发成本,我们需要开发每个算子的切分推导规则。

1.2 算子切分推导规则介绍

自动并行中用户只标记了组网中部分 Tensors (Op)的切分状态(DistAttr),模型组网中仅有部分Tensors 有分布式属性(DistAttr)。在自动并行实际执行过前,模型组网中的所有Tensors(Ops) 都需要有一个确定的切分状态,每个Local 设备(进程)需要根据分布式属性信息判断在执行过程中当前设备(进程)需要的通信和切分操作。 算子切分推导规则的目标就是在利用该算子进行计算的同时,根据输入组网中的部分切分状态,推导补全整个组网的切分状态。

理想情况下,每个Op 都会有一个专门的切分推导规则。

自动并行中,每个算子的执行逻辑如下:

Image

数据的切分信息分为3种:

  • shard:在指定的张量维度上对张量进行切分。
  • replicate:跨设备复制tensor,每个rank得到完全相同的tensor
  • partial:一种张量,在不同设备上具有相同的形状,但在每个设备上只有部分值。 它可以进一步规约操作(即sum/min/max)以获得分布式张量。 这通常用作中间表示。

合法切分状态的推导

  • 对于一个孤立的 Tensor,我们可以随意设置它的在集群中的切分状态。 但是对于一个算子其输入输出Tensor 的切分状态不能是任意的。
  • 基于算子自身的运算逻辑,给定一个输入(输出)的切分切状态,其输出(输入) 合法的切分状态是一个有限的集合。 基于用户部分切分标记,如何推导合法的切分状态。
  • 合法的定义为:Tensor 切分状态(shape,partial)满足Op 的运算要求,并能获得正确的(local)计算结果。

Matmul算子切分推导规则举例

import paddle.distributed as dist mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
X W Y = XW
[Replicate, Replicate] [Replicate, Replicate] [Replicate, Replicate]
[Replicate, Replicate] [Replicate, 'x'] [Replicate, 'x']
['y', Replicate] [Replicate, Replicate] ['y', Replicate]
[Replicate, 'x'] ['y', Replicate] [Replicate, Replicate]
['y', Replicate] [Replicate, 'x'] ['y', 'x']
['y', 'x'] ['x', Replicate] ['y', Replicate]

不同算子的分布式规则只和算子逻辑相关。

1.3 算子切分推导规则开发原则

具体开发规则可以参考切分推导规则参考文档

半自动框架为步骤中的公共逻辑提供 公共 utils 函数,在Paddle/paddle/phi/infermeta/spmd_rules目录下,开发者只需要实现与 Op 自身运算法则相关的逻辑。
Image

开发者需要在该目录下创建算子同名的cpp文件(如:argmax.cc, argmax.h),在文件当中为该算子开发切分推导规则(若该算子有对应的反向算子,则要求在该文件中连同该反向算子的切分推导规则一同开发),开发完成以后,需在该文件夹中的rule.cc文件当中完成算子切分推导规则注册,即可完成对该算子的切分推导规则开发。此外,开发者需要在 Paddle/test/cpp/auto_parallel 文件夹下增加该算子切分推导规则的单元测试,完成新增代码的单测覆盖,并测试该算子切分推导规则的正确性。

二、任务详情,需要开发的算子列表

本期需要增加切分推导规则的算子如下,整体进展:

序号 算子名称 队伍名称/状态/PR 难度
1 topk @ooooo-create #72499
0.5×⭐️
2 cummax @ooooo-create #72720
0.5×⭐️
3 cummin @ooooo-create #72720
0.5×⭐️
4 batch_norm @Glencsa #72918
0.5×⭐️
5 mean_all @ooooo-create #72479
0.5×⭐️
6 unique @ooooo-create #72824
0.5×⭐️
7 expand_as @ooooo-create #72845 #73107
0.5×⭐️
8 log_softmax @ooooo-create #72720
0.5×⭐️
9 group_norm @Glencsa #72946
0.5×⭐️
10 index_select @ooooo-create #72727
0.5×⭐️
11 instance_norm @Glencsa #72938
0.5×⭐️
12 label_smooth @ooooo-create #72845
0.5×⭐️
13 sync_batch_norm @Glencsa #72918
0.5×⭐️
14 roll @ooooo-create #72740
0.5×⭐️
15 index_put @Juggler-YAN #73155
@ttuuuuyyyj
@Glencsa #73486
@ghost
0.5×⭐️
16 depthwise_con2d @NKNaN #73134
17 conv2d_transpose @NKNaN #73188
18 conv3d @NKNaN #72882
19 roi_align @ooooo-create #72925

⭐️ 提交PR 模版 ⭐️:

// ------- PR 标题 --------

[Auto Parallel] Add spmd rule No.xxx for xxx and xxx_grad ops. 

// ------- PR 内容 --------

PR Category Auto Parallel PR types New features Description 为xxx和xxx_grad算子增加切分推导规则

三、参考指南

建议

  • 开发者需要先看懂Paddle/paddle/phi/infermeta/spmd_rules目录下的utils.cc基础函数的一些使用,以及已算子命名的cpp文件(argmax.cc, argmax.h)的代码逻辑,有助于帮助开发者对新算子切分推导规则开发的快速入门。
  • 开发者在写单测代码时,应该考虑算子切分推导尽可能多的切分情况。

题目讲解见录屏文件:https://meeting.tencent.com/crm/l59EWmRZc4 (00:00:00~00:15:30)

看板信息

任务方向 任务数量 提交作品 / 任务认领 提交率 完成 完成率
算子切分推导规则开发 19 19 / 19 100.0% 19 100.0%

统计信息

排名不分先后 @ooooo-create (11) @Glencsa (5) @NKNaN (3)

Metadata

Metadata

Labels

No labels
No labels

Type

No type

Projects

Status

Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions