- Notifications
You must be signed in to change notification settings - Fork 5.9k
Description
一、需求背景
飞桨正在构建一套新的IR体系.在新IR下飞桨基于动态图的更规范的算子定义(ops.yaml、legacy_ops.yaml)生成了新IR体系下的算子.在新的IR体系下仍然需要保证旧IR的兼容性.为此飞桨提供了ProgramTranslator(相关代码位于paddle/fluid/ir_adaptor/translator/),它可以将旧IR表示下的计算图翻译为新IR下的计算图.目前,ProgramTranslator的核心工作是完成单个OP的翻译.也就是将旧IR下定义的OP(一般定义在paddle/fluid/operators文件夹下)翻译为新IR下定义的算子.但是,ProgramTranslator在翻译单个OP时会遇到下述两个问题:
- 有相当一部分的算子在新IR下并没有定义导致
ProgramTranslator翻译该OP时无法得到新IR下对应的OP定义. - 新旧IR下有相当一部分算子的定义是不一致的,
ProgramTranslator在翻译这部分算子时通用的翻译方案并不适合这些算子的翻译,我们需要单独定义它们的转换工作.
例如,对于dpsgd算子,它对应得单测文件是test/legacy_test/test_dpsgd_op.py,我们在新IR下执行这个单测可以看到报错信息如下:
Op dpsgd should have corresponding OpInfo pd_op.dpsgd这条错误提示主要是在说得不到dpsgd这个算子在新IR定义的OpInfo.这是由于新IR下没有定义dpsgd这个算子造成的.我们需要补充dpsgd算子在新IR下的定义.
修复以下Op单测在PIR下测试成功:
| 序号 | 单测 | 认领人/状态/PR号 |
|---|---|---|
| 1 | test_fake_quantize_op | @cmcamdy |
| 2 | test_empty_op | @longranger2 |
| 3 | test_matrix_rank_op | @xingmingyyj |
| 4 | test_sgd_op_bf16 | @xingmingyyj |
| 5 | test_tril_triu_op | @xingmingyyj |
| 6 | test_tdm_sampler_op | @xingmingyyj |
| 7 | test_activation_op | @xingmingyyj |
| 8 | test_shuffle_batch_op | @xingmingyyj |
| 9 | @xingmingyyj | |
| 10 | test_row_conv_op | @xingmingyyj |
| 11 | test_retinanet_detection_output | |
| 12 | test_partial_sum_op | @cmcamdy |
| 13 | test_partial_concat_op | @cmcamdy |
| 14 | test_nce | @xingmingyyj |
| 15 | test_match_matrix_tensor_op | @xingmingyyj |
| 16 | test_lookup_table_v2_bf16_op | @xingmingyyj |
| 17 | test_momentum_op | @xingmingyyj |
| 18 | test_identity_loss_op | |
| 19 | test_ftrl_op | @xingmingyyj |
| 20 | test_fake_dequantize_op | |
| 21 | test_data_norm_op | @Eacient |
| 22 | test_elementwise_sub_op | @Eddie-Wang1120 |
| 23 | test_distribute_fpn_proposals_op | @xingmingyyj |
| 24 | test_lamb_op | |
| 25 | test_quant_linear_op | |
| 26 | test_fused_token_prune_op | @Eddie-Wang1120 |
| 27 | test_softmax_mask_fuse_op | @Eddie-Wang1120 |
| 28 | test_fused_adam_op | @Eddie-Wang1120 |
| 29 | test_coalesce_tensor_op | @Eddie-Wang1120 |
| 30 | test_assign_pos_op | @Eddie-Wang1120 |
| 31 | test_number_count_op | @DrRyanHuang |
| 32 | test_bilateral_slice_op | @MayYouBeProsperous |
| 33 | test_fused_conv2d_add_act_op | @MayYouBeProsperous |
| 34 | test_rank_attention_op | @austin-00 |
| 35 | test_batch_fc_op | @austin-00 @Dmovic |
| 36 | test_adam_op | |
| 37 | test_stft_op | |
| 38 | test_semi_auto_parallel_c_cross_entropy | @xingmingyyj |
| 39 | test_c_reduce_min_translate | @xingmingyyj |
PR提交模板
- PR标题
【PIR OpTest Fix No.1】 fix test_fake_quantize_op- PR内容
### PR types Others ### PR changes Others ### Description PIR Op单测修复 修复单测 `test_fake_quantize_op` 修复后打开`FLAGS_enable_pir_in_executor`单测是否通过:否 报错信息:.......认领方式
请大家以 comment 的形式认领任务,如:
【报名】:1、3、12-13 多个任务之间需要使用中文顿号分隔,报名多个连续任务可用横线表示,如 2-5
PR 提交格式:在 PR 的标题中以 【PIR OpTest Fix No.xxx】 开头,注明任务编号
看板信息
| 任务方向 | 任务数量 | 提交作品 / 任务认领 | 提交率 | 完成 | 完成率 |
|---|---|---|---|---|---|
| 快乐开源 | 39 | 32 / 32 | 82.05% | 27 | 69.23% |
统计信息
排名不分先后 @longranger2 (1) @xingmingyyj (14) @cmcamdy (2) @Eddie-Wang1120 (5) @DrRyanHuang (1) @MayYouBeProsperous (2) @austin-00 (1) @Dmovic (1)
二、Tutorial
以dpsgd为例,下面是在新IR下补充定义dpsgd的流程.
2.1 为待修复OP在新Ir下补充定义
首先,可以看到dpsgd这个算子在旧IR下的定义.
class DpsgdOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { //infer shape } phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Param"),ctx.GetPlace()); } }; class DpsgdOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Param", "(Tensor) Input parameter"); AddInput("Grad", "(Tensor) Input gradient"); AddInput("LearningRate", "(Tensor) Learning rate"); AddOutput("ParamOut", "(Tensor) Output parameter"); AddAttr<float>("clip", "(float, default 0.9) " "Exponential decay rate for the " "1st moment estimates.") .SetDefault(10.0f); AddAttr<float>("batch_size", "(float, default 0.999) " "exponential decay rate for the weighted " "infinity norm estimates.") .SetDefault(16.0f); AddAttr<float>("sigma", "(float, default 1.0e-8) " "Constant for numerical stability") .SetDefault(1.0f); AddAttr<int>( "seed", "(int, default 0)" .SetDefault(0); } }; PD_REGISTER_STRUCT_KERNEL( dpsgd, CPU, ALL_LAYOUT, ops::DpsgdOpKernel, float, double) {}可以看到这里它有名为Param, Grad, LearningRate三个输入,一个名为ParamOut的输出,四个名为clip batch_size, sigma, seed的参数.
以及注册的kernel名称和GetExpectedKernelType信息.
我们可以在paddle/fluid/pir/dialect/operator/ir/ops.yaml文件中补充定义
- op: dpsgd args: (Tensor param, Tensor grad, Tensor learning_rate, float clip = 10.0f, float batch_size = 16.0f, float sigma = 1.0f, int seed = 0) output: Tensor(param_out) infer_meta: func: DpsgdInferMeta kernel: func: dpsgd data_type: paramyaml中前向算子的配置规则如下,下表引自文档开发C++算子的3.1 算子 Yaml 文件配置
| 配置项 | 配置内容及规则 |
|---|---|
| api | 算子名称,与该算子 Python API 函数名相同(命名方式为:全小写+下划线) |
| args | 算子输入参数,与该算子 Python API 函数的输入参数对应(当前支持的输入数据类型包括:Tensor, Tensor[], float, double, bool, int, int64_t, int[], int64_t[], str, Place, DataType, DataLayout, IntArray, Scalar)。我们一般称这里 Tensor 类型的参数为 Input(输入),非 Tensor 类型的参数为 Attribute(属性) 注:Tensor[]表示 Tensor 数组;IntArray 为 int 类型数组,主要用于表示 shape,index 和 axes 等类型数据,可以直接使用 Tensor 或者普通整型数组构造,目前仍在测试阶段,如非必要暂不建议使用;Scalar 表示标量,可以支持不同的普通数据类型 |
| output | 算子输出类型(目前支持 Tensor 和 Tensor[]类型),多个输出间用逗号“,”分隔开。可以使用”()”选择性标记输入的名字,如未标记默认为'out' 注:当返回类型为 Tensor[]时,由于数组的 size 要在 kernel 执行前推导完成,所以需要在 Tensor[]后的'{}'内通过表达式指定返回数组的 size,如:Tensor[](out){input.size()} |
| infer_meta | InferMeta 函数负责根据输入变量推断返回 Tensor 的维度与类型,这里是对算子使用的 InferMeta 函数进行配置 |
| infer_meta:func | 调用的 InferMeta 函数,这里 trace 调用的是 TraceInferMeta 函数 |
| infer_meta:param | InferMeta 函数的输入参数,可以对 args 中的参数进行选择传入,未配置则默认传入 args 中的所有参数。示例中未配置本项,所以传入的参数为[x, offset, axis1, axis2]。output 项中的参数作为输出无需配置会自动传入 InferMeta 函数中 |
| kernel | 算子的计算 Kernel 配置 |
| kernel:func | 算子对应 kernel 函数的注册名 |
| kernel:param | kernel 函数的输入参数,配置规则与 InferMeta 函数的 param 配置项相同 |
| kernel:data_type | 根据指定参数推导调用 kernel 的 data_type(对应 kernel 函数的模板参数'T'),默认不进行配置,会根据输入 Tensor 自动进行推导。如果 kernel 的 data_type 类型由某个输入参数(Tensor 或者 DataType 参数),需要将该参数的变量名填入该项。示例中未配置则 kernel 的 data_type 由输入变量'x'决定 |
| kernel:backend | 根据指定参数来选择调用 kernel 的 Backend(Kernel 执行的具体设备,如 CPU、GPU 等),默认不进行配置,会根据输入 Tensor 自动进行推导。如果 kernel 执行的 backend 类型由某个输入参数(Tensor 或者 Backend 参数)决定,需要将该参数的变量名填入该项。示例中未配置则 kernel 执行的 Backend 与输入变量'x'的 Backend 相同 |
| backward | 算子对应的反向算子名称,如果没有反向则不需要配置,示例中 trace 算子的反向为 trace_grad |
| 特殊配置项(目前特殊配置项还处于不稳定阶段,后续可能会有调整更新) | |
| optional | 指定输入 Tensor 为可选输入,用法可参考 dropout 中 seed_tensor(python/paddle/utils/code_gen/legacy_ops.yaml 中) |
| inplace | 算子对指定的输入做原位处理并作为输出结果返回,使用格式:(x -> out),具体用法可参考 relu 算子 特殊规则:如果 api 中算子名称有'_'后缀则只生成支持 inplace 功能的接口,如果算子名称没有'_'后缀,则会同时生成支持 inplace 操作的接口(自动添加'_'后缀)和不支持 inplace 的普通接口共两套接口 |
| view | 与 inplace 机制类似,区别在于 view 模式返回的结果只是与输入共享内存,并不是输入 Tensor 变量本身,使用格式:(x -> out),具体用法可参考 reshape 算子 |
| intermediate | 标记前向计算中输出的用于反向计算的中间变量,不会出现在 Python API 的返回结果中,相关设计正在完善中,新增算子时不建议使用 |
| invoke | 复用已有的算子接口或实现自定义的 C++ API,配置时以函数调用的形式配置即可,使用 invoke 时则不需要配置 infer_meta 和 kernel。 a. 如果是复用已有算子,需要被复用的算子为前向算子且两者的返回值类型相同,可参考 zeros_like 算子 b. 如果是实现自定义的 C++ API,需要在'paddle/phi/api/lib/api_custom_impl.h'声明自定义实现函数并在'paddle/phi/api/lib/api_custom_impl.cc'中进行实现,具体可参考 embedding 算子 |
需要说明的是因为dpsgd属于优化型算子,没有对应的backwardop所以这里不用配置,对于backwardop的配置后面会补充说明。另外,飞桨根据算子签名(ops.yaml、legacy_ops.yaml)自动生成了新IR体系下的算子定义。具体逻辑可实现在 paddle/fluid/pir/dialect/CMakeLists.txt 及其调用的相关脚本,生成的 OP 定义文件在build/paddle/fluid/pir/dialect/operator/ir/pd_op.cc
2.2 为dpsgd配置op_compat.yaml文件
ProgramTranslator需要确定应该将旧IR的哪个参数对应到新IR的哪个参数.这种映射定义在 paddle/phi/api/yaml/op_compat.yaml中.
一般地我们只需要将旧IR下对应驼峰命名转为新IR下的下划线命名即可.
- op: dpsgd inputs: {param: Param,grad: Grad,learning_rate: LearningRate} outputs: param_out: ParamOut由该yaml文件生成的cpp文件是paddle/fluid/ir_adaptor/translator/op_compat_info.cc该文件指导ProgramTranslator的翻译.
2.3 为dpsgd配置InferMeta
InferMeta函数是根据输入参数,推断算子输出 Tensor 基本信息的函数,推断的信息包括输出 Tensor 的shape、data type,同时它也承担了检查输入数据维度、类型等是否合法的功能。
说明:InferMeta 与 kernel 共同组成了一个算子的运算过程。InferMeta 在 kernel 前执行,用于维度、数据类型等信息的计算处理,这些信息在没有具体数据时依然可以通过输入参数完成输出结果的信息推导(例如两个维度为 2x3 的张量相加,输出结果的维度也一定是 2x3),可以利用这些信息优化训练过程中资源的分配和使用,kernel 中也不再需要专门推导这些信息。kernel 则用于具体数据的逻辑计算,为 InferMeta 函数推导得到的张量填充具体的结果值。
InferMeta和InferShape有什么区别?为什么不继续叫InferShape?
InferMeta 的 Meta 来源于 DenseTensor 中的 meta 成员,在 PHI 中,一个 op 有两大组件,InferMeta 和 Kernel。这里 InferMeta 覆盖了 InferShape 的功能,但又不限于 InferShape,除了对 dims 和 lod 的推断,InferMeta 中也会承担 dtype 和 layout 的推断,这一点和原先是不一样的。
修复Op单测时,并不需要我们真正去实现InferMeta,我们只需要根据需要修复Op的InferShape函数稍加修改即可,但是dtype信息需要我们单独设置一下,因为InferShape,不包含dtype信息.一般地,outputs的dtype信息要inputs的dtype一致即可.这里以dpsgd为例,介绍注册InferMeta的流程.
void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true, platform::errors::NotFound( "Input(Param) of DpsgdOp should not be null.")); PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true, platform::errors::NotFound( "Input(Grad) of DpsgdOp should not be null.")); PADDLE_ENFORCE_EQ( ctx->HasInput("LearningRate"), true, platform::errors::NotFound( "Input(LearningRate) of DpsgdOp should not be null.")); PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(), framework::proto::VarType::LOD_TENSOR, platform::errors::InvalidArgument( "The input var's type should be phi::DenseTensor, " "but the received is %s", ctx->GetInputsVarType("Param").front())); PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad").front(), framework::proto::VarType::LOD_TENSOR, platform::errors::InvalidArgument( "The input var's type should be phi::DenseTensor, " "but the received is %s", ctx->GetInputsVarType("Grad").front())); PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true, platform::errors::NotFound( "Output(ParamOut) of DpsgdOp should not be null.")); auto lr_dims = ctx->GetInputDim("LearningRate"); PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1, platform::errors::InvalidArgument( "Learning rate should have 1 dimension. But Received " "LearningRate's dims [%s].", phi::product(lr_dims))); auto param_dims = ctx->GetInputDim("Param"); PADDLE_ENFORCE_EQ( param_dims, ctx->GetInputDim("Grad"), platform::errors::InvalidArgument( "Param and Grad input of DpsgdOp should have same dimension. But " "received Para's dim [%s] and Grad's dim [%s].", param_dims, ctx->GetInputDim("Grad"))); ctx->SetOutputDim("ParamOut", param_dims); }我们可以看到InferShape主要完成两个工作,首先是对InferShapeContext中是否具有某些参数进行检查,对与InferMeta可以直接忽略这部分逻辑,这里是因InferMeta的参数并不涉及InferShapeContext,它直接接收对应的变量,而不用去InferShapeContext中去查找.后面对某些变量进行维度推导的逻辑才是我们需要重点关注的,我们需要将这些逻辑照搬到InferMeta里面.
void DpsgdInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, float clip, float batch_size, float sigma, int size, MetaTensor* param_out) { auto lr_dims = learning_rate.dims(); PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1, phi::errors::InvalidArgument( "Learning rate should have 1 dimension. But Received " "LearningRate's dims [%s].", phi::product(lr_dims))); auto param_dims = param.dims(); PADDLE_ENFORCE_EQ( param_dims, grad.dims(), phi::errors::InvalidArgument( "Param and Grad input of DpsgdOp should have same dimension. But " "received Para's dim [%s] and Grad's dim [%s].", param_dims, grad.dims())); param_out->set_dims(param_dims); param_out->set_dtype(param.dtype()); }这里给出了DpsgdInferMeta的实现,可以发现和InferShape的实现基本一致,仅更换了一些API.另外,根据dtype的设置原则,将param_out(输出)的dtype设置为param(输入)的dtype即可.
最后,关于InferMeta函数的实现以及声明位置,可以参考文档开发C++算子的说明.
InferMeta 的实现位置
InferMeta 的文件放置规则(paddle/phi/infermeta 目录下,以 Tensor 输入个数为判定标准):
nullary.h:没有输入 Tensor 参数的函数unary.h:仅有一个输入 Tensor 参数的函数binary.h:有两个输入 Tensor 参数的函数ternary.h:有三个输入 Tensor 参数的函数multiary.h:有三个以上输入 Tensor 或者输入为vector<Tensor>的函数backward.h:反向算子的 InferMeta 函数一律在此文件中,不受前序规则限制
另外,函数在单个文件的中的排序方式是按照函数名称的字典序进行放置.
2.4 本地测试通过后提交Pr
在本地开启FLAGS_PIR_OPTEST 和 FLAGS_PIR_OPTEST_WHITE_LIST单测通过之后可以提交PR.另外,可以打开FLAGS_enable_pir_in_executor进行进一步测试,并在PR提交时说明执行情况,如果失败标明原因即可.
在新IR下对应单测跑通之后,可以提交PR,需要将单测名称加入到test/white_list/pir_op_test_white_list中,这样可以开启PR-CI-PY3,PR-CI-MAC-Python3和PR-CI-Coverage两条流水线上在新IR下执行该单测.
2.5 补充说明需要注册BackwardOp的情况
事实上大部分Op是具有反向算子的.一个算子是否具有反向算子,我们大致可以根据在旧Ir下定义该算子时有没有定义他的反向Op.这里以repeat_interleave为例.
class RepeatInterleaveOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) the input tensor."); AddInput("RepeatsTensor", "the 1-D tensor containing the repeats alongsize the axis.") .AsDispensable(); AddOutput("Out", "the output tensor."); AddAttr<int>("Repeats", "the number of repetitions for each element.") .SetDefault(0); AddAttr<int>("dim", "the dimension in which we repeat.").SetDefault(0); } }; template <typename T> class RepeatInterleaveGradMaker : public framework::SingleGradOpMaker<T> { public: using framework::SingleGradOpMaker<T>::SingleGradOpMaker; protected: void Apply(GradOpPtr<T> op) const override { op->SetType("repeat_interleave_grad"); op->SetInput("X", this->Input("X")); op->SetInput("RepeatsTensor", this->Input("RepeatsTensor")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetAttrMap(this->Attrs()); } };可以发现repeat_interleave是有反向op的,我们要补充他的反向算子的定义.
首先,我们定义正向算子
- op : repeat_interleave args : (Tensor x, int repeats, int axis) output : Tensor(out) infer_meta : func : RepeatInterleaveInferMeta kernel : func : repeat_interleave data_type : x backward: repeat_interleave_grad因为该算子具有反向算子所以我们需要配置backward.
然后,我们根据RepeatInterleaveGradMaker为其在对应的backward.yaml中注册反向算子.
- backward_op : repeat_interleave_grad forward : repeat_interleave(Tensor x, int repeats, int axis) -> Tensor(out) args : (Tensor x, Tensor out_grad, int repeats, int axis) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta param : [x] kernel : func : repeat_interleave_grad反向算子也是一种算子,所以其配置和正向算子大同小异,具体字段的含义同样需要参考文档开发C++算子的说明.
backward.yaml 中反向算子的配置规则如下:
| 配置项 | 配置内容及规则 |
|---|---|
| backward_op | 反向算子名称,一般命名方式为:前向算子名称+'_grad',二阶算子则为前向算子名称+'_double_grad' |
| forward | 对应前向算子的名称、参数、返回值,需要与 ops.yaml 中前向算子配置一致 |
| args | 反向算子输入参数, 示例中'x'表示将前向的'x'变量输入到反向,'out_grad'表示前向输出'out'对应的反向梯度 约束条件 1:所有参数需要在 forward 配置项的参数中(输入、输出以及输出对应的反向梯度)找到对应(根据变量名匹配) 约束条件 2:反向输入参数需要以:a.前向输入 Tensor b.前向输出 Tensor c.前向输出 Tensor 的反向梯度 d.前向非 Tensor 类型属性变量(Attribute) 的顺序排列,反向计算中不需要使用的前向变量无须添加 |
| output | 反向算子输出,顺序需要与前向输入 Tensor 一致,比如前向输入(Tensor x, Tensor y),则反向输出必须为 Tensor(x_grad), Tensor(y_grad) |
| infer_meta | 与前向配置规则相同 |
| kernel | 与前向配置规则相同 |
| backward | 反向算子对应的更高阶反向算子名称,如一阶反向算子的反向为二阶反向算子 |
| 特殊配置项(目前特殊配置项还处于不稳定阶段,后续可能会有调整更新) | |
| no_need_buffer | 可选配置,标记的 Tensor 变量在前向运行完成后,持有的内存或显存会被释放,以减少训练过程中的内存使用。trace_grad 由于反向算子只需要前向变量'x'的维度信息,不需要内存数据,所以可以标记为 no_need_buffer 提前释放内存 注意:由于 Tensor 内存被释放后会影响 dtype 接口的使用,所以需要在 kernel 的 data_type 配置项中指定其他的 Tensor 来推导 kernel 的 data_type |
| optional | 与前向配置规则相同 |
| inplace | 与前向配置规则相同 |
反向算子的op_compat.yaml不用特殊配置,只需要在为其对应的正向算子配置的同时,加入backward字段即可.具体地
- op : repeat_interleave backward : repeat_interleave_grad inputs : x : X outputs : out : Out attrs : {repeats : Repeats, axis : dim}另外,反向算子作为算子它也需要对应的InferMeta,配置方法和正向算子相同.
三、验收标准
分别打开如下两组flag:
FLAGS_PIR_OPTEST和FLAGS_PIR_OPTEST_WHITE_LISTFLAGS_enable_pir_in_executor
执行对应的单测,打开第一组falg单测通过即可将单测名称加入test/white_list/pir_op_test_white_list中,然后提交PR.但是需要在PR中补充说明在FLAGS_enable_pir_in_executor打开时的单测执行情况.它是否可以执行通过,如果不通过输出了什么样的报错信息.
四、单测执行的整体流程
Pir下执行单测大致可以分为计算图转换,StandaloneExecutor初始化和StandalneExecutor执行三个阶段.大体流程如下图所示:
下面展开对这些流程展开描述:
4.1 Python侧测试的流程
单测一般会调用在test/legacy_test/op_test.py中的OpTest的成员函数check_out.在OpTest中需要重点关注
_check_ir_output和_check_ir_grad_output两个函数,这里会涉及到几个flag的使用.
-
FLAGS_PIR_OPTEST和FLAGS_PIR_OPTEST_WHITE_LIST这两个FLAG同时打开时
_check_ir_output才会执行,该函数会暂存FLAGS_enable_pir_in_executor和FLAGS_pir_apply_inplace_pass两个flag,然后构建执行器,执行计算图,拿到结果进行比较. -
FLAGS_enable_pir_in_executor该
flag会在python/paddle/base/executor.py中控制是否将旧Ir下的计算图翻译为新Ir下的计算图,并执行. -
FLAGS_PIR_OPTEST_RELAX_CHECK该
flag表示可以对于新Ir下的计算图输出结果,允许一定的误差. -
FLAGS_PIR_NO_CHECK该
flag开启表示不关心当前计算图在新Ir下的执行结果,只要执行成功即可.一般使用在涉及到具有随机性的Op,比如Seed.
所以,新Ir下测试Op的流程主要逻辑为在旧Ir下执行一次计算图拿到预期结果,然后开启相应的flag在新Ir下再执行一次计算图将两次所得到的结果进行比较.
同样以dpsgd为例,下面展开说明python侧的执行流程
class TestDpsgdOp(OpTest): ... def test_check_output(self): self.check_output()它的单测继承自OpTest会直接执行OpTest的check_output方法.在OpTest的check_output中会调用check_output_with_place.可以发现check_output_with_place定义了内部类Checker,以及它的子类StaticChecker,DygraphChecker和PirChecker. check_output_with_place后续逻辑大体上也是围绕这三个测试类展开的,如下述代码基本逻辑是初始化一个测试类然后执行类里的check方法.
static_checker = StaticChecker(self, self.outputs) static_checker.check() outs, fetch_list = static_checker.outputs, static_checker.fetch_listcheck方法的基本逻辑是
def check(self): """ return None means ok, raise Error means failed. the main enter point of Checker class """ self.init() self.calculate_output() self.compare_outputs_with_expects()这里可以看出check方法首先去执行计算图得到输出,然后再将得到的输出和预期的结果进行比较.在calculate_output中会调用op_test._calc_output函数,在该函数中可以发现如下逻辑
executor = Executor(place) outs = executor.run( program, feed=feed_map, fetch_list=fetch_list, return_numpy=False, ) self._check_ir_output(place, program, feed_map, fetch_list, outs)这里会初始化一个执行器,然后执行该计算图.之后会执行_check_ir_output,
def _check_ir_output(self, place, program, feed_map, fetch_list, outs): if os.getenv("FLAGS_PIR_OPTEST") is None: return if os.getenv("FLAGS_PIR_OPTEST_WHITE_LIST") is None: return if self.check_prim or self.check_prim_pir: return if self._check_cinn: return stored_flag = get_flags( [ 'FLAGS_enable_pir_in_executor', "FLAGS_pir_apply_inplace_pass", ] ) try: set_flags( { "FLAGS_enable_pir_in_executor": True, "FLAGS_pir_apply_inplace_pass": 0, } ) new_scope = paddle.static.Scope() executor = Executor(place) new_program = None if isinstance(program, paddle.static.CompiledProgram): new_program = base.CompiledProgram( program._program, build_strategy=program._build_strategy ) else: new_program = program.clone() ir_outs = executor.run( new_program, feed=feed_map, fetch_list=fetch_list, return_numpy=False, scope=new_scope, ) assert len(outs) == len( ir_outs ), "Fetch result should have same length when executed in pir" check_method = np.testing.assert_array_equal if os.getenv("FLAGS_PIR_OPTEST_RELAX_CHECK", None) == "True": check_method = lambda x, y, z: np.testing.assert_allclose( x, y, err_msg=z, atol=1e-6, rtol=1e-6 ) if os.getenv("FLAGS_PIR_NO_CHECK", None) == "True": check_method = lambda x, y, err_msg: None for i in range(len(outs)): check_method( outs[i], ir_outs[i], err_msg='Operator Check (' + self.op_type + ') has diff at ' + str(place) + '\nExpect ' + str(outs[i]) + '\n' + 'But Got' + str(ir_outs[i]) + ' in class ' + self.__class__.__name__, ) finally: set_flags(stored_flag)通过阅读_check_ir_output源码,对于新IR下的单测执行就很清晰了,如果我们打开了之前提到的FLAGS_PIR_OPTEST和FLAGS_PIR_OPTEST_WHITE_LIST那么这个测试逻辑就会生效.它会提前保存FLAGS_enable_pir_in_executor和FLAGS_pir_apply_inplace_pass然后打开FLAGS_enable_pir_in_executor,再执行一次计算图,因为FLAGS_enable_pir_in_executor被打开这次测试会再新IR下执行,然后再对执行结果进行比较即可.
4.2 C++侧执行新Ir计算图流程
4.2.1 计算图的翻译
首先,新Ir表示下的计算图是通过翻译旧Ir下的计算图得到的,所以首先需要考虑的是计算图的翻译.在executor.py中使用函数translate_to_pir完成计算图的翻译.该函数被绑定到C++侧的paddle::TranslateLegacyProgramToProgram上.关于paddle::TranslateLegacyProgramToProgram的相关代码在paddle/fluid/ir_adaptor/translator文件夹下.ProgramTranslator的核心部分是OpTranslator.对于该部分的工作原理文档program_translator有详细介绍,为某个Op制定特殊的转换规则我们可以为该Op指定OpTranscriber,以下内容引用自该文档
虽然我们需要为某些 Op 定义特殊的转换规则,但是并不是所有的转换逻辑都是特殊的,比如说,有些时候我们只需要针对属性进行特殊处理,那么就没有再把其他部分的转换规则重复一遍.因此我们通过继承与成员函数的重载,允许只自定义转换流程中某一部分的转换规则.可以这样理解,一个 Op 的转换函数 OpTranslateFn 实际上是由若干个函数指针组成的,如果需要为某个 Op 定义特殊规则,一般只需要更改其中的一个或几个函数指针即可.
目前,我们支持重载的模块如下:public: virtual pir::Operation* operator()(pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, pir::Block* block); public: virtual pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc); virtual std::vector<pir::Value> GenerateOperationInput( pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, pir::Block* block); virtual std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput( pir::IrContext* ctx, const OpDesc& op_desc, const OpOutputInfoList& output_infos); virtual void HandleNonexistentAttribute(pir::IrContext*, pir::AttributeMap* attribute_map, const OpAttributeInfo& info); virtual pir::AttributeMap TranslateOpAttribute( pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc); virtual pir::OpResult GetAttributeAsInput(pir::IrContext* ctx, pir::Block* block, const OpDesc& op_desc, const OpInputInfo& input_info); virtual void RecordOpResultMapping(pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, pir::Operation* operation, const OpOutputMapping& arg_to_idx); public: virtual InputHandlerFn GetSpecialInputHandlers( const std::string& input_name) { return nullptr; } virtual AttributeHandlerFn GetSpecialAttributeHandlers( const std::string& input_name) { return nullptr; } virtual void InsertSliceOperationForInput(pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const OpInputInfoList& input_infos, pir::Block* block); };
在修复某个Op的单测时可能会需要涉及到为该Op设计特殊的转换规则,此时需要考虑重写上述函数.
4.2.2 StandaloneExecutor执行计算图
翻译成新Ir表示的计算图最终会提交给StandaloneExecutor执行器执行.下面展开StandaloneExecutor执行器执行计算图的流程.
4.2.2.1 StandaloneExecutor初始化
StandaloneExecutor(const platform::Place& place, const interpreter::Plan& plan_, Scope* scope);其中,place是指定的运算设备,plan_管理着需要执行器执行的计算图.scope用于管理变量,op在运行时从scope中查找输入和输出变量.
StandaloneExecutor的初始化过程分为下述步骤:
- 为每个计算图创建一个新的
scope加入mirco_batch_scopes_中. - 遍历
pir_progam中的每一个Op,统计fetch_var_name将其加入到fetch_var_names_中. - 调用
PdOpLowerToKernelPass得到更加底层的KernelDialect计算图描述. - 新建
InterpreterCore.
下面展开第3步的过程:
第3步主要由paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc中的ProcessBlock函数完成,该函数将由OperationDialect下的计算图转化为KernelDialect下的计算图表示.
首先,会通过GetOpYamlInfoParser函数得到op_info_parser,拿到op_info_parser之后就可以访问对应Op的OpInfoTupe,通过解析ops.yaml自动生成的对应Op的GetOpInfo函数定义在build/fluid/pir/dialect/operator/ir/pd_op.cc中.在GetOpInfo中保存了OpRunTimeInfo信息,定义如下:
struct OpRunTimeInfo { std::string infer_meta_func; std::vector<std::string> infer_meta_param; std::string kernel_func; std::vector<std::string> kernel_param; std::vector<std::string> kernel_key_dtype; std::vector<std::string> kernel_key_backend; std::vector<std::pair<std::string, std::string>> inplace; std::vector<std::pair<std::string, std::string>> view; };这里保存了对应Op的Kernel信息.ProcessBlock中的kernel_fn_str对应OpRuntimeInfo里面的kernel_func字段.
GetKernelKey的处理逻辑十分复杂,它主要返回Kernel的backend,layout以及dtype.这里可能最需要关注的是dtype因为配置的ops.yaml信息会直接影响到OpRunTimeInfo的kernel_key_dtype,而最终得到的KernelKey和kernel_key_dtype直接相关.
在得到KernelKey之后,会调用BuildOpInputList和BuildOutputType完成对inputs和op_output_types的build.
最后调用BuildPhiKernelOp完成KernelDialect下的计算图构建.
下面展开第4步过程:
在构建InterpreterCore时会构建InterpreterBaseImpl的实现PirInterpreter.在PirInterpreter的构造函数中会执行BuildScope创建inner_scpoe.inner_scope被ValueExecutionInfo管理.在BuildScope中为每个普通Op的Value实例化对应的Varibale.
4.2.2.2 StandaloneExecutor执行
执行StandaloneExecutor::Run在新Ir下会调用PirInterpreter中的Run函数.主要执行下述三个操作:
- BuildInstruction
- PreAnalysis
- TraceRunImpl/MultiThreadRunImpl
下面展开介绍这3个操作:
BuildInstruction
不考虑控制流Op,BuildInstruction主要涉及构建LegacyKernelInstruction和PhiKernelInstruction.这里以构建PhiKernelInstruction为例.
首先,执行SetKernelType主要是为该Op设置OpFuncType,OpFuncType描述Kernel的运行硬件环境信息.
然后,通过op_info中的GetInterfaceImpl得到InferMetaInterface::Concept,其中记录了该Op的InferMeta函数的函数指针.如果该Op存在InferMeta函数,则需要为其准备运行环境infer_meta_context.该步骤主要由BuildPhiContext函数实现.然后,根据Op里保存的kernel_name和kernel_key在Kernelfactory中选择合适的phi_kernel,使用BuildPhiContext函数构造出kernel_context.
PreAnalysis
该部分主要完成对instructions之间的依赖分析,帮助后续对Op执行调度,方便并发执行.为jit program的变量注册等待事件等.
TraceRunImle
惰性初始化GC,然后调用TraceRunInstructionList.TraceRunInstructionList会调用RunInstructionBase等待同步时机到来后,调度当前PhiKernelInstruction重写的Run函数执行.在Run函数中依次执行InferMeta和Kernel.
五、Q&A
1.对于randint这类随机性的Op两次执行的结果不相同怎么验证正确性?
A: 此类Op需要将单测名称加入test/white_list/pir_op_test_no_check_list,在执行CI测试时FLAGS_PIR_NO_CHECK会自动打开.参考PR#57826
2.报错The kernel with key (CPU, Undefined(AnyLayout), int64) of kernel seed is not registered是什么原因,如何修复?
A: 该错误主要是由新旧ir下的GetExpectedKernelType不一致造成的,旧Ir下kerneltype为INT32,而新ir下的GetExpectedKernelType返回的是Out的dtype,修改新ir下的GetExpectedKernelType问题解决.参考PR#58552
3.FatalError: Segmentation fault is detected by the operating system.调用栈最后执行的是paddle::operators::DataOp::GetExpectedKernelType(paddle::framework::ExecutionContext const&) const,这是什么原因造成的?
A: 主要是单测机制导致的测试在开启FLAGS_enable_new_ir_in_executor时执行错误,开启FLAGS_PIR_OPTEST, FLAGS_PIR_OPTEST_WHITE_LIST单测成功,可以暂时不做处理.
4.OpYamlInfoParser在解析runtime_info.kernel_param时会将可变属性放入kernel_fn_attr_params这样对于新Ir下定义的sparse_momentum_op(定义了Scalar axis)会造成AttributeMap中不存在axis属性的问题,如何解决?
A: 对于此类legacy op暂时将可变属性统一放入kernel_fn_tensor_params中.解决方案是需要给OpYamlInfoParser多增加一个属性,用来判断当前翻译的Op是非为legacy op.
5.旧Ir下的Op往往需要根据参数的不同翻译为新Ir的多个Op,如何实现?
A: 需要重写LoopkUpOpInfo函数,甚至根据输入的不同需要重写其他函数.可以参考PR#58379
6.报错error: 'eager_api_XXX' was not decalred in this scope如何解决?
A: 需要将op名称添加到paddle/fluid/pir/dialect/op_generator/ops_api_gen.py中的NO_NEED_GEN_STATIC_ONLY_APIS这里同样需要保证字典序.最后,再次编译,问题消失.
7.其他错误如何快速定位发生错误的位置?
A: 首先,一个单测文件中可以有多个测试,这里建议只保留一个出错的单测.然后,可以在padde/base/executor.py中直接print(program)打印出计算图,可以比较新旧Ir下计算图的异同.最后,可以通过export GLOG_v=10打开全量日志,观察日志可以大体确定出错的位置.确定错误位置不确定修改方案可以联系 @xingmingyyj和@kangguangli.
8.RuntimeError: (PreconditionNotMet) op [pd_op.xxx] kernel output args defs should equal op outputs此类问题的原因是什么?怎么解决?
A:此类问题是Legacy op的kernel和phi kernel的推导机制不一致造成的。如果kernel是通过PD_REGISTER_STRUCT_KERNEL注册的,需要把他加在LegacyOpList中,单独处理。
Metadata
Metadata
Assignees
Labels
Type
Projects
Status