Skip to content
32 changes: 25 additions & 7 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,32 @@ OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
auto op_info_it = OpRegistry::op_info_map().find(type_);

if (op_info_it == OpRegistry::op_info_map().end()) {
return;
}

auto* op_proto = op_info_it->second.proto_;
if (op_proto == nullptr) {
return;
}

static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
// If op_proto is registered for current operator.
// We will generate output names if user not set.
for (auto& output : op_proto->outputs()) {
// If outputs is duplicable, that output could be [0,N]'s outputs. Default
// names cannot be set.
if (output.duplicable()) {
continue;
}

auto& outs = outputs_[output.name()];
// Set default output name, if it is not set
if (outs.empty()) {
outs.push_back(type_ + "@GENERATE_OUTPUT@" +
std::to_string(gUniqId.fetch_add(1)));
}
}
}
Expand Down
57 changes: 26 additions & 31 deletions paddle/framework/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ bool IsCompileGPU() {
#endif
}

// Deserialize a protobuf-encoded OpDesc received from Python and build the
// corresponding operator through the registry. Raises (via PADDLE_ENFORCE)
// when the bytes cannot be parsed or required fields are missing.
static std::unique_ptr<OperatorBase> create_op_from_pb(
    const py::bytes &protobin) {
  OpDesc op_desc;
  // Partial parsing tolerates unknown fields in the wire format...
  PADDLE_ENFORCE(op_desc.ParsePartialFromString(protobin),
                 "Cannot parse user input to OpDesc");
  // ...but every required field must still be present.
  PADDLE_ENFORCE(op_desc.IsInitialized(),
                 "User OpDesc is not initialized, reason %s",
                 op_desc.InitializationErrorString());
  return OpRegistry::CreateOp(op_desc);
}

PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");

Expand Down Expand Up @@ -185,16 +196,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("__str__", string::to_string<const platform::CPUPlace &>);

py::class_<OperatorBase>(m, "Operator")
.def_static("create",
[](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def_static("create", create_op_from_pb)
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
Expand All @@ -216,33 +218,26 @@ All parameter, weight, gradient are variables in Paddle.
.def("support_gpu", &OperatorBase::SupportGPU);

py::class_<operators::NetOp, OperatorBase>(m, "Net")
.def_static("create",
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
.def(py::init<>())
.def("add_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});
.def("complete_add_op",
[](operators::NetOp &self) { self.CompleteAddOp(); })
.def("create_and_add_op",
[](operators::NetOp &self, const py::bytes &str) {
self.AddOp(create_op_from_pb(str));
return self.ops_.back().get();
},
py::return_value_policy::reference);

// recurrent_op
py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def_static("create",
[](const py::bytes &str) {
return static_cast<operators::RecurrentOp *>(
create_op_from_pb(str).release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
Expand Down
36 changes: 36 additions & 0 deletions python/paddle/v2/framework/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import paddle.v2.framework.core as core
from paddle.v2.framework.op import OpDescCreationMethod, get_all_op_protos


class Network(object):
    """Python-side wrapper around the C++ ``core.Net`` operator container.

    Operators can be appended either as already-constructed objects or by
    registered op type name plus keyword arguments describing inputs,
    outputs and attributes.
    """

    def __init__(self):
        self.net = core.Net()

    def add_op(self, op, **kwargs):
        """Append an operator to this network.

        :param op: an operator object, another Network, or a registered op
            type name (str/unicode) to be constructed from ``kwargs``.
        :param kwargs: op inputs/outputs/attributes, used only when ``op``
            is a type name.
        :return: when constructing by name: the single non-intermediate
            output name, None if there is none, or the list of all of them.
            Returns None when adding an existing operator/network.
        """
        if len(kwargs) == 0:
            # Adding an already-built operator (or nested network) directly.
            if isinstance(op, Network):
                self.add_op(op.net)
            else:
                self.net.add_op(op)
        else:
            if not isinstance(op, (str, unicode)):
                raise TypeError("Op should be str/unicode or another operator")
            all_protos = get_all_op_protos()
            if op not in all_protos:
                # BUGFIX: message must be %-formatted, not passed as a
                # second exception argument.
                raise RuntimeError("Op %s has not been registered" % op)
            # Reuse the proto map fetched above instead of querying again.
            method = OpDescCreationMethod(all_protos[op])
            op_desc = method(**kwargs)
            op = self.net.create_and_add_op(op_desc.SerializeToString())
            outs = op.no_intermediate_outputs()
            if len(outs) == 1:
                return outs[0]
            elif len(outs) == 0:
                return None
            else:
                return outs

    def __str__(self):
        return str(self.net)

    def complete_add_op(self):
        self.net.complete_add_op()
20 changes: 11 additions & 9 deletions python/paddle/v2/framework/op.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2

# Module-level cache: op type name -> OpProto. Populated lazily on first use.
g_all_op_protos = None


def get_all_op_protos():
    """
    Get all registered op protos from the Paddle C++ core.

    The protos are fetched from C++ and deserialized only once; subsequent
    calls return the cached dictionary.

    :return: dict mapping op type name (str) to its OpProto message
    """
    global g_all_op_protos
    if g_all_op_protos is None:
        g_all_op_protos = dict()
        # Only hit the C++ core when the cache is cold.
        protostrs = core.get_all_op_protos()
        for pbstr in protostrs:
            op_proto = framework_pb2.OpProto.FromString(str(pbstr))
            g_all_op_protos[op_proto.type] = op_proto
    return g_all_op_protos


def is_str(s):
Expand Down Expand Up @@ -141,7 +145,7 @@ def __impl__(*args, **kwargs):
class OperatorFactory(object):
def __init__(self):
self.op_methods = dict()
for op_proto in get_all_op_protos():
for op_proto in get_all_op_protos().values():
method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method

Expand Down Expand Up @@ -184,9 +188,7 @@ class __RecurrentOp__(object):
def __init__(self):
# cache recurrent_op's proto
if self.__proto__ is None:
for op_proto in get_all_op_protos():
if op_proto.type == self.type:
self.__proto__ = op_proto
self.__proto__ = get_all_op_protos()[self.type]

def __call__(self, *args, **kwargs):
if self.type not in args and 'type' not in kwargs:
Expand Down
1 change: 1 addition & 0 deletions python/paddle/v2/framework/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_network SRCS test_network.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
6 changes: 3 additions & 3 deletions python/paddle/v2/framework/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def fc(X, W, Y):
ret_v = core.Net.create()
ret_v = core.Net()

ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y))
Expand All @@ -14,11 +14,11 @@ def fc(X, W, Y):

class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.Net.create()
net = core.Net()
op1 = Operator("add_two", X="X", Y="Y", Out="Out")
net.add_op(op1)

net2 = core.Net.create()
net2 = core.Net()
net2.add_op(fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
Expand Down
19 changes: 19 additions & 0 deletions python/paddle/v2/framework/tests/test_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from paddle.v2.framework.network import Network
import unittest


class TestNetwork(unittest.TestCase):
    def test_add_op(self):
        """Build two nested networks by op name and check the wiring.

        Previously this test only printed the network; a test should assert.
        """
        net = Network()
        out = net.add_op("add_two", X="A", Y="B")
        # add_two has a single non-intermediate output, so a name is returned.
        self.assertIsNotNone(out)
        out = net.add_op("mul", X=out, Y="D")
        self.assertIsNotNone(out)
        net2 = Network()
        net2.add_op("add_two", X=out, Y="E")
        net2.complete_add_op()
        net.add_op(net2)
        net.complete_add_op()
        # The fully-built net must render a non-empty description.
        self.assertNotEqual(len(str(net)), 0)


if __name__ == '__main__':
unittest.main()
3 changes: 1 addition & 2 deletions python/paddle/v2/framework/tests/test_operator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import unittest
import paddle.v2.framework.op as op
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2


class TestGetAllProtos(unittest.TestCase):
def test_all(self):
all_protos = op.get_all_op_protos()
all_protos = op.get_all_op_protos().values()
self.assertNotEqual(0, len(all_protos))

for each in all_protos:
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/v2/framework/tests/test_recurrent_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def create_rnn_op(self):
memories=["h@alias"])

def create_step_net(self):
stepnet = core.Net.create()
stepnet = core.Net()
x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("add_two", X="Wx", Y="Uh", Out="sum")
Expand Down