Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions paddle/framework/op_proto.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ message AttrProto {

// Supported attribute comments. It helps 3rd-party language generate doc-string.
required string comment = 3;

// If that attribute is generated, it means the Paddle third language
// binding has responsibility to fill that attribute. End-User should
// not set that attribute.
optional bool generated = 4 [default=false];
}

// Input or output message for 3rd-party language binding.
Expand All @@ -45,6 +50,40 @@ message VarProto {

// The comment for that input. It helps 3rd-party language generate doc-string.
required string comment = 2;

// Is that input/output could be a list or not.
// If so, that Op should write a attributed named `input_format` or
// `output_format`.
//
// e.g.
// If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W`
// could be multiple, so the multiple of `X` and `W` is True, and OpDesc
// will hold a attribute of them.
//
// The Op desc of same fc could be
// {
// "type": "fc",
// "input": ["X1", "X2", "W1", "W2", "b"],
// "output": "fc.out",
// "attrs" : {
// "input_format": [0, 2, 4, 5]
// }
// }
//
optional bool multiple = 3 [default=false];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if repeated a better name?


// It marks that output is a temporary output. That output is not used by
// user, but used by other op internally as input. If other op is not use
// that output, it could be optimized early.
//
// Attribute temporary_index will be set in OpDesc if there is some
// outputs are temporary.
//
// output = [ "xxx.out1", "xxx.tmp", "xxx.out2"],
// attrs = {
// "temporary_index": [1]
// }
optional bool temporary = 4 [default=false];
}

// Op protocol message for 3rd-party language binding.
Expand Down
97 changes: 94 additions & 3 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <algorithm>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
Expand Down Expand Up @@ -59,25 +61,52 @@ class OpProtoAndCheckerMaker {
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}

~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); }

protected:
void AddInput(const std::string& name, const std::string& comment) {
void AddInput(const std::string& name, const std::string& comment,
bool multiple = false) {
auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name;
*input->mutable_comment() = comment;
input->set_multiple(multiple);
if (multiple) {
SetHasMultipleInput();
}
}

void AddInputs(const std::string& name, const std::string& comment) {
AddInput(name, comment, true);
}

void AddOutput(const std::string& name, const std::string& comment) {
void AddOutput(const std::string& name, const std::string& comment,
bool temporary = false, bool multiple = false) {
auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name;
*output->mutable_comment() = comment;
output->set_multiple(multiple);
if (multiple) {
SetHasMultipleOutput();
}
output->set_temporary(temporary);
if (temporary) {
SetHasTemporaryOutput();
}
}

void AddOutputs(const std::string& name, const std::string& comment,
bool temporary = false) {
AddOutput(name, comment, temporary, true);
}

template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment) {
const std::string& comment,
bool generated = false) {
auto attr = proto_->mutable_attrs()->Add();
*attr->mutable_name() = name;
*attr->mutable_comment() = comment;
attr->set_generated(generated);
AttrTypeHelper::SetAttrType<T>(attr);
return op_checker_->AddAttrChecker<T>(name);
}
Expand All @@ -86,8 +115,70 @@ class OpProtoAndCheckerMaker {
*(proto_->mutable_comment()) = comment;
}

private:
void SetHasMultiple(const std::string& in_out, bool* flag) {
if (!*flag) {
AddAttr<std::vector<int>>(in_out + "_format",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

each variable length input will has a input format ?

Copy link
Collaborator Author

@reyoung reyoung Jul 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input_format is an attribute for that Op.

"The multiple index of " + in_out +
"\n"
R"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.

e.g.
input = ["a", "b", "c", "d", "e", "f"]
input_format = [0, 4, 5, 6]

means
The number of all input variables this op is six, and they are segmented into
three inputs.

The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC",
/*generated*/ true);
*flag = true;
}
}

void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
void SetHasMultipleOutput() {
SetHasMultiple("output", &has_multiple_output_);
}

void SetHasTemporaryOutput() {
if (!has_temporary_output_) {
AddAttr<std::vector<int>>("temporary_index",
R"DOC(The temporary index of output.

Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.

Add a mark to which output is temporary is helpful for future optimization.
)DOC",
/*generated*/ true)
.SetDefault(std::vector<int>());
has_temporary_output_ = true;
}
}

void CheckNoDuplicatedAttrs() {
std::unordered_set<std::string> names;
size_t cnt = 0;
for (auto& attr : proto_->attrs()) {
names.insert(attr.name());
++cnt;
}
PADDLE_ENFORCE(names.size() == cnt,
"Cannot register two attribute in same name!");
}

OpProto* proto_;
OpAttrChecker* op_checker_;
bool has_multiple_input_{false};
bool has_multiple_output_{false};
bool has_temporary_output_{false};
};

class OpRegistry {
Expand Down
15 changes: 13 additions & 2 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddInputs("input", "input of cosine op");
AddOutput("output", "output of cosine op",
/*temporary*/ true);
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
Expand Down Expand Up @@ -117,11 +118,20 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
}

static void SetInputFormat(paddle::framework::OpDesc* desc) {
auto attr = desc->add_attrs();
attr->set_name("input_format");
attr->set_type(paddle::framework::INTS);
attr->mutable_ints()->Add(0);
attr->mutable_ints()->Add(1);
}

TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
op_desc.add_inputs("ii");
op_desc.add_outputs("oo");
SetInputFormat(&op_desc);

// attr 'test_attr' is not set
bool caught = false;
Expand Down Expand Up @@ -163,6 +173,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
SetInputFormat(&op_desc);
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
Expand Down