Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)

cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)

py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

OperatorBase* GradOpCreator::Create() {
OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op);
return grad_op;
}

OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const std::vector<int>& format,
InOutType type) {
Expand All @@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
end_idx);
}

void GradOpCreator::BuildOpInOutArgList() {
void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format =
Expand All @@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
}
}

void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out,
std::vector<int>& format,
VarIndexMap* varmap, int& idx,
Expand All @@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
format.push_back(in_out.size());
}

void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ struct OpInOutArg {
size_t end_idx_;
};

class GradOpCreator {
class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>;

public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
GradOpBuilder(const OperatorBase* op) : op_(op) {}
OperatorBase* Build();

private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
Expand All @@ -8,7 +8,7 @@ USE_OP(add_two);
namespace paddle {
namespace framework {

TEST(GradOpCreator, AddTwo) {
TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
Expand Down
57 changes: 31 additions & 26 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"

Expand Down Expand Up @@ -222,7 +222,7 @@ class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
creators()[op_type] = [] { return new OpType; };
op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
Expand All @@ -245,17 +245,19 @@ class OpRegistry {
}
}

template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
grad_ops()[op_type] = grad_op_type;
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(),
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);

auto op = op_create_it->second();
Expand Down Expand Up @@ -300,8 +302,8 @@ class OpRegistry {

static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
GradOpBuilder builder(op.get());
std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init();
return grad_op;
}
Expand All @@ -311,9 +313,9 @@ class OpRegistry {
return protos_;
};

static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_;
}

static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
Expand All @@ -322,12 +324,12 @@ class OpRegistry {
return maps_;
}

private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
}

private:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why make op_creators public? In my mind, it is only used by CreateOp and RegisterOp.

Copy link
Collaborator Author

@JiayiFeng JiayiFeng Jul 25, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also used by GradOpBuilder, because all gradient operators are also registered in op_creators.

static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
Expand All @@ -353,11 +355,11 @@ class OpRegisterHelper {
}
};

template <typename OpType>
template <typename GradOpType>
class GradOpRegisterHelper {
public:
GradOpRegisterHelper(const char* op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type);
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
}
};

Expand All @@ -383,13 +385,16 @@ class GradOpRegisterHelper {
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
#__grad_op_type); \
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
}

/**
* Macro to Register OperatorKernel.
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace paddle

REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
6 changes: 3 additions & 3 deletions paddle/operators/add_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
auto it = protos.find("add_two");
ASSERT_NE(it, protos.end());
auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
auto it1 = grad_creators.find("add_two");
ASSERT_NE(it1, grad_creators.end());
auto& op_creators = paddle::framework::OpRegistry::op_creators();
auto it1 = op_creators.find("add_two_grad");
ASSERT_NE(it1, op_creators.end());
}
2 changes: 1 addition & 1 deletion paddle/operators/mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace paddle

REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);

REGISTER_OP_CPU_KERNEL(
mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
2 changes: 1 addition & 1 deletion paddle/operators/sigmoid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
REGISTER_OP(sigmoid,
paddle::operators::SigmoidOp,
paddle::operators::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);

REGISTER_OP_CPU_KERNEL(
sigmoid,
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/softmax_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;

REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);