Skip to content

Conversation

@chenwhql
Copy link
Contributor

@chenwhql chenwhql commented Jan 9, 2022

PR types

Function optimization

PR changes

Others

Describe

[PTen] Move GetExpectedPtenKernelArgs function into pten for infrt

infrt在执行时,会直接载入program解析,并根据program的信息去选择kernel,但不会依赖paddle的op体系,因此,用于Op参数和kernel参数的映射函数GetExpectedPtenKernelArgs需要作为pten的组件,可以由框架灵活使用,并且不反向依赖fluid的实现

  1. 对fluid的适配,见代码中修改的示例
  • TODO:适配上,可以直接放到执行体系中,而不是在每个op中都重写GetExpectedPtenKernelArgs方法,鉴于一次性修改涉及文件较多,会在后续PR逐步修改
  1. 对infrt的适配,可以通过继承创建合适的ArgumentMappingContext,从而在infrt中调用相应匹配函数
// 伪代码 class ProtoArgumentMappingContext : public pten::ArgumentMappingContext { public: ProtoArgumentMappingContext(proto::OpProto* op, proto::BlockDesc* block) : op_(op), block_(block) {} bool HasInput(const std::string& name) const override { // simple search for (int i = 0; i < proto_->input_size(); ++i) { auto& in = proto_->inputs()[i]; if (in.name() == name) { return true; } } return false; } bool HasOutput(const std::string& name) const override { // simple search for (int i = 0; i < proto_->output_size(); ++i) { auto& out = proto_->outputs()[i]; if (out.name() == name) { return true; } } return false; } bool HasAttr(const std::string& name) const override { // simple search for (int i = 0; i < proto_->attrs_size(); ++i) { auto& attr = proto_->attrs()[i]; if (attr.name() == name) { return true; } } return false; } size_t InputSize(const std::string& name) const override { return proto_->input_size(); } size_t OutputSize(const std::string& name) const override { return proto_->output_size(); } bool IsDenseTensorInput(const std::string& name) const override { for (int i = 0; i < block_.vars_size(); ++i) { auto& var = block_.vars()[i]; if (var.name() == name) { if (var.type() == proto::VarType::LOD_TENSOR) { return true; } } } // TODO(chenweihang): throw error when cannot found return false; } bool IsSelectedRowsInput(const std::string& name) const override { for (int i = 0; i < block_.vars_size(); ++i) { auto& var = block_.vars()[i]; if (var.name() == name) { if (var.type() == proto::VarType::SELECTED_ROWS) { return true; } } } // TODO(chenweihang): throw error when cannot found return false; } private: proto::OpProto op_*; proto::BlockDesc block_*; }; 
@paddle-bot-old
Copy link

paddle-bot-old bot commented Jan 9, 2022

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


namespace pten {

KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉如果这个函数也能注册到一个Map里,原来的Op可能就不用再单独写GetExpectedPtenKernelArgs了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,下一个PR完成

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql merged commit 3a23c1a into PaddlePaddle:develop Jan 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants