Commit ee22a43

Merge pull request #4684 from reyoung/feature/parameter
Feature/parameter
2 parents: 843ed8e + f185af8

File tree

7 files changed: +222 -30 lines

doc/design/python_api.md

Lines changed: 6 additions & 6 deletions
````diff
@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
 ```python
 class Program(objects):
     def __init__(self):
-        self.proto = core.NewProgram() # a C++ ProgramDesc pointer.
+        self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
         self.blocks = vector<Block>()
         self.blocks.append(Block(self, -1)) # the global block
         self.current_block = 0 # initialized to the global block
@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
 ```python
 class Block(objects):
     def __init__(self, program, parent_idx):
-        self.proto = core.NewBlock(program.proto)
+        self.desc = core.NewBlock(program.desc)
         self.program = program
         self.vars = map<string, Variable>()
         self.ops = vector<Operator>()
@@ -98,11 +98,11 @@ class Operator(object):
                  outputs,# dict<stirng, Variable>
                  attrs # dict<string, Any>
                  ):
-        self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs)
-        core.infer_shape(self.proto, inputs, outputs)
+        self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
+        core.infer_shape(self.desc, inputs, outputs)
 
     def type(self):
-        return self.proto.type()
+        return self.desc.type()
 ```
 
 `Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
@@ -124,7 +124,7 @@ class Variable(object):
             name = unique_name_generator()
         self.name = name
         self.block = block
-        self.proto = core.NewVarDesc(block.proto, name, shape, lod_level)
+        self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
         self.writer = None
 ```
 
````
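The rename is mechanical: every Python wrapper (`Program`, `Block`, `Operator`, `Variable`) now stores its C++ handle in `self.desc` instead of `self.proto`, since the handle wraps a `ProgramDesc`/`BlockDesc`/`OpDesc`/`VarDesc` object rather than a raw protobuf message. A minimal sketch of the renamed attribute in use, assuming a Paddle build that ships the `paddle.v2.framework.core` extension; the name `x` and its shape are illustrative:

```python
# Minimal sketch, assuming a built paddle.v2.framework package.
from paddle.v2.framework.graph import g_program

block = g_program.current_block()
x = block.create_var(name="x", shape=[32, 784], dtype="float32", lod_level=0)

# Each wrapper now holds its C++ desc handle in `.desc` (formerly `.proto`);
# queries such as shape are forwarded to C++ through it.
print(x.shape)        # (32, 784), read back from the C++ VarDesc
print(block.desc.id)  # this block's index inside the ProgramDesc
```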

paddle/framework/var_desc.cc

Lines changed: 8 additions & 0 deletions
```diff
@@ -32,5 +32,13 @@ std::vector<int64_t> VarDescBind::Shape() const {
 DataType VarDescBind::GetDataType() const {
   return desc_.lod_tensor().data_type();
 }
+
+void VarDescBind::SetLoDLevel(int32_t lod_level) {
+  desc_.mutable_lod_tensor()->set_lod_level(lod_level);
+}
+
+int32_t VarDescBind::GetLodLevel() const {
+  return desc_.lod_tensor().lod_level();
+}
 }  // namespace framework
 }  // namespace paddle
```
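`SetLoDLevel` and `GetLodLevel` are thin accessors over the `lod_level` field of the `lod_tensor` message inside `VarDesc`. A hedged sketch of the same round trip on the raw protobuf message; the import path is an assumption about where the module generated from `framework.proto` lives:

```python
# Hedged sketch on the raw protobuf message; the framework_pb2 import
# path is assumed, not confirmed by this commit.
from paddle.v2.framework.proto import framework_pb2

desc = framework_pb2.VarDesc()
desc.lod_tensor.lod_level = 2          # what SetLoDLevel(2) writes in C++
assert desc.lod_tensor.lod_level == 2  # what GetLodLevel() reads back
```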

paddle/framework/var_desc.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -66,6 +66,10 @@ class VarDescBind {
 
   DataType GetDataType() const;
 
+  void SetLoDLevel(int32_t lod_level);
+
+  int32_t GetLodLevel() const;
+
  private:
   VarDesc desc_;
 };
```

paddle/pybind/protobuf.cc

Lines changed: 3 additions & 1 deletion
```diff
@@ -166,7 +166,9 @@ void BindVarDsec(py::module &m) {
       .def("set_shape", &VarDescBind::SetShape)
       .def("set_data_type", &VarDescBind::SetDataType)
       .def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
-      .def("data_type", &VarDescBind::GetDataType);
+      .def("data_type", &VarDescBind::GetDataType)
+      .def("lod_level", &VarDescBind::GetLodLevel)
+      .def("set_lod_level", &VarDescBind::SetLoDLevel);
 }
 
 void BindOpDesc(py::module &m) {
```
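With the two new `.def(...)` entries, the LoD level becomes readable and writable from Python. A short sketch through the bound `core` module, assuming a built extension; `"sentence"` is an illustrative variable name:

```python
# Sketch of the new bindings in use, assuming a built
# paddle.v2.framework.core extension.
import paddle.v2.framework.core as core

prog = core.ProgramDesc.instance()
var = prog.block(0).new_var("sentence")
var.set_lod_level(1)         # dispatches to VarDescBind::SetLoDLevel
assert var.lod_level() == 1  # dispatches to VarDescBind::GetLodLevel
```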
python/paddle/v2/framework/graph.py

Lines changed: 134 additions & 23 deletions
```diff
@@ -1,50 +1,121 @@
 import paddle.v2.framework.core as core
 import collections
+import numpy as np
+import copy
 
 __all__ = ['Block', 'Variable', 'Program', 'Operator']
 
 
 class Variable(object):
-    def __init__(self, block, name=None, shape=None, dtype=None,
-                 lod_level=None):
+    def __init__(self,
+                 block,
+                 name=None,
+                 shape=None,
+                 dtype=None,
+                 lod_level=None,
+                 **kwargs):
         self.block = block
 
         if name is None:
             name = Variable._unique_var_name_()
-        self.proto = self.block.proto.new_var(name)
+        try:
+            self.desc = self.block.desc.var(name)
+            is_new_var = False
+        except core.EnforceNotMet:
+            self.desc = self.block.desc.new_var(name)
+            is_new_var = True
 
         if shape is not None:
-            self.proto.set_shape(shape)
-
+            if is_new_var:
+                self.desc.set_shape(shape)
+            else:
+                old_shape = self.shape
+                shape = tuple(shape)
+                if shape != old_shape:
+                    raise ValueError(
+                        "Variable {0} has been created before. The previous "
+                        "shape is {1}; the new shape is {2}. They are not "
+                        "matched.".format(self.name, old_shape, shape))
         if dtype is not None:
-            # TODO(yuyang18): Convert dtype from numpy.dtype
-            self.proto.set_data_type(dtype)
+            if not isinstance(dtype, core.DataType):
+                dtype = Variable._convert_np_dtype_to_dtype_(dtype)
+            if is_new_var:
+                self.desc.set_data_type(dtype)
+            else:
+                old_dtype = self.data_type
+                if dtype != old_dtype:
+                    raise ValueError("Variable {0} has been created before. "
+                                     "The previous data type is {1}; the new "
+                                     "data type is {2}. They are not "
+                                     "matched.".format(self.name, old_dtype,
+                                                       dtype))
 
         if lod_level is not None:
-            # TODO(yuyang18): set_lod_level is not defined.
-            self.proto.set_lod_level(lod_level)
-
+            if is_new_var:
+                self.desc.set_lod_level(lod_level)
+            else:
+                if lod_level != self.lod_level:
+                    raise ValueError("Variable {0} has been created before. "
+                                     "The previous lod_level is {1}; the new "
+                                     "lod_level is {2}. They are not "
+                                     "matched".format(self.name, self.lod_level,
+                                                      lod_level))
         self.block.vars[name] = self
         self.op = None
 
-    # TODO(yuyang18): Get methods
+    @property
+    def name(self):
+        return self.desc.name()
+
+    @property
+    def shape(self):
+        # convert to tuple, same as the numpy API.
+        return tuple(self.desc.shape())
+
+    @property
+    def data_type(self):
+        return self.desc.data_type()
+
+    @property
+    def lod_level(self):
+        return self.desc.lod_level()
 
     @staticmethod
     def _unique_var_name_():
         uid = core.unique_integer()  # unique during whole process.
         return "_generated_var_%d" % uid
 
+    @staticmethod
+    def _convert_np_dtype_to_dtype_(np_dtype):
+        dtype = np.dtype(np_dtype)
+        if dtype == np.float32:
+            return core.DataType.FP32
+        elif dtype == np.float64:
+            return core.DataType.FP64
+        elif dtype == np.float16:
+            return core.DataType.FP16
+        elif dtype == np.int32:
+            return core.DataType.INT32
+        elif dtype == np.int16:
+            return core.DataType.INT16
+        elif dtype == np.int64:
+            return core.DataType.INT64
+        elif dtype == np.bool:
+            return core.DataType.BOOL
+        else:
+            raise ValueError("Not supported numpy dtype " + str(dtype))
+
 
 class Operator(object):
     def __init__(self,
                  block,
-                 proto,
+                 desc,
                  type=None,
                  inputs=None,
                  outputs=None,
                  attrs=None):
         self.block = block
-        self.proto = proto
+        self.desc = desc
         if type is not None:
             # TODO.
             pass
```
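The `try/except core.EnforceNotMet` branch gives `Variable` get-or-create semantics: constructing a `Variable` with an existing name re-attaches to the existing `VarDesc` and cross-checks any metadata passed in. A sketch, mirroring `test_variable.py` below and assuming a built `paddle.v2.framework` package; `"w"` and its shape are illustrative:

```python
# Get-or-create semantics of Variable.
from paddle.v2.framework.graph import g_program

b = g_program.current_block()
w = b.create_var(name="w", shape=[10, 10], dtype="float32")

w_again = b.create_var(name="w")      # re-attaches to the same VarDesc
assert w_again.shape == (10, 10)

try:
    b.create_var(name="w", shape=[20, 10])  # conflicting shape is rejected
except ValueError as e:
    print(e)
```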
```diff
@@ -58,36 +129,40 @@ def __init__(self,
             # TODO
             pass
 
-        # TODO: Getters
+    # TODO: Getters
 
 
 class Block(object):
     def __init__(self, program, idx):
-        self.proto = program.proto.block(idx)
+        self.desc = program.desc.block(idx)
         self.vars = dict()  # var_name --> var
         self.ops = collections.deque()  # operator list
         self.program = program
 
     @property
     def parent_idx(self):
-        return self.proto.parent
+        return self.desc.parent
 
     @property
     def idx(self):
-        return self.proto.id
+        return self.desc.id
 
     def create_var(self, *args, **kwargs):
         return Variable(self, *args, **kwargs)
 
+    def create_parameter(self, *args, **kwargs):
+        global_block = self.program.global_block()
+        return Parameter(global_block, *args, **kwargs)
+
     def append_op(self, *args, **kwargs):
-        op_proto = self.proto.append_op()
-        op = Operator(self, op_proto, *args, **kwargs)
+        op_desc = self.desc.append_op()
+        op = Operator(self, op_desc, *args, **kwargs)
         self.ops.append(op)
         return op
 
     def prepend_op(self, *args, **kwargs):
-        op_proto = self.proto.prepend_op()
-        op = Operator(self, op_proto, *args, **kwargs)
+        op_desc = self.desc.prepend_op()
+        op = Operator(self, op_desc, *args, **kwargs)
         self.ops.appendleft(op)
         return op
 
```
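`create_parameter` deliberately ignores the receiving block and constructs the `Parameter` in the program's global block, so parameters are shared across sub-blocks; `prepend_op` exists so initializer ops can be pushed to the front of the op deque, ahead of anything that reads the parameter. A sketch of the delegation, assuming a built `paddle.v2.framework` package:

```python
# create_parameter delegates to the global block
# (see also test_parameter.py below).
from paddle.v2.framework.graph import g_program

sub = g_program.create_block()
p = sub.create_parameter(name="fc.w", shape=[784, 100], dtype="float32")

assert p.block.idx == 0  # the parameter lives in block 0, not in `sub`
```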

```diff
@@ -104,7 +179,7 @@ def instance(cls):
     def __init__(self):
         assert not hasattr(self.__class__,
                            '_instance'), 'Do not call constructor directly!'
-        self.proto = core.ProgramDesc.instance()
+        self.desc = core.ProgramDesc.instance()
         self.blocks = [Block(self, 0)]
         self.current_block_idx = 0
 
@@ -116,7 +191,7 @@ def current_block(self):
 
     def create_block(self):
         new_block_idx = len(self.blocks)
-        self.proto.append_block(self.current_block().proto)
+        self.desc.append_block(self.current_block().desc)
         self.current_block_idx = new_block_idx
         self.blocks.append(Block(self, self.current_block_idx))
         return self.current_block()
```
```diff
@@ -125,5 +200,41 @@ def rollback(self):
         self.current_block_idx = self.current_block().parent_idx
 
 
+class Parameter(Variable):
+    def __init__(self, block, shape, dtype, **kwargs):
+        if shape is None or dtype is None:
+            raise ValueError("Parameter must set shape and dtype")
+        if len(shape) == 0:
+            raise ValueError("Parameter shape cannot be empty")
+
+        for each in shape:
+            if each < 0:
+                raise ValueError("Parameter shape should not be related with "
+                                 "batch-size")
+
+        Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
+        self.trainable = kwargs.get('trainable', True)
+        self.init_attr = kwargs.get('initialize_attr', {
+            'type': 'uniform_random',
+            'min': -1.0,
+            'max': 1.0
+        })
+
+        self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
+        self._append_initialize_ops_()
+
+    def _append_initialize_ops_(self):
+        attr = copy.deepcopy(self.init_attr)
+        op_type = attr.pop('type', None)
+        block = self.block
+        assert isinstance(block, Block)
+        shape = self.shape
+        attr['dims'] = shape
+        attr['data_type'] = int(self.data_type)
+        op = block.prepend_op(
+            type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr)
+        self.op = op
+
+
 # program is a global instance.
 g_program = Program.instance()
```
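`_append_initialize_ops_` turns `init_attr` into a real operator: the `'type'` key selects the op (by default `uniform_random`), the remaining keys plus the parameter's `dims` and `data_type` become the op's attributes, and the op is prepended to the owning block. A sketch of what construction leaves behind on the object, assuming a built `paddle.v2.framework` package; `"proj.w"` is an illustrative name:

```python
# What Parameter construction leaves behind.
from paddle.v2.framework.graph import g_program

blk = g_program.current_block()
p = blk.create_parameter(name="proj.w", shape=[16, 16], dtype="float32")

assert p.trainable                              # default from kwargs
assert p.init_attr["type"] == "uniform_random"  # default initializer
assert p.op is not None and p.op in p.block.ops # the prepended init op
```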
python/paddle/v2/framework/tests/test_parameter.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -0,0 +1,27 @@
+import unittest
+from paddle.v2.framework.graph import g_program
+import paddle.v2.framework.core as core
+
+
+class TestParameter(unittest.TestCase):
+    def test_param(self):
+        b = g_program.create_block()
+        param = b.create_parameter(
+            name='fc.w',
+            shape=[784, 100],
+            dtype='float32',
+            initialize_attr={
+                'type': 'uniform_random',
+                'seed': 13,
+                'min': -5.0,
+                'max': 5.0
+            })
+        self.assertIsNotNone(param)
+        self.assertEqual('fc.w', param.name)
+        self.assertEqual((784, 100), param.shape)
+        self.assertEqual(core.DataType.FP32, param.data_type)
+        self.assertEqual(0, param.block.idx)
+
+
+if __name__ == '__main__':
+    unittest.main()
```
python/paddle/v2/framework/tests/test_variable.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -0,0 +1,40 @@
+import unittest
+from paddle.v2.framework.graph import Variable, g_program
+import paddle.v2.framework.core as core
+import numpy as np
+
+
+class TestVariable(unittest.TestCase):
+    def test_np_dtype_convert(self):
+        DT = core.DataType
+        convert = Variable._convert_np_dtype_to_dtype_
+        self.assertEqual(DT.FP32, convert(np.float32))
+        self.assertEqual(DT.FP16, convert("float16"))
+        self.assertEqual(DT.FP64, convert("float64"))
+        self.assertEqual(DT.INT32, convert("int32"))
+        self.assertEqual(DT.INT16, convert("int16"))
+        self.assertEqual(DT.INT64, convert("int64"))
+        self.assertEqual(DT.BOOL, convert("bool"))
+        self.assertRaises(ValueError, lambda: convert("int8"))
+
+    def test_var(self):
+        b = g_program.current_block()
+        w = b.create_var(
+            dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
+        self.assertEqual(core.DataType.FP64, w.data_type)
+        self.assertEqual((784, 100), w.shape)
+        self.assertEqual("fc.w", w.name)
+        self.assertEqual(0, w.lod_level)
+
+        w = b.create_var(name='fc.w')
+        self.assertEqual(core.DataType.FP64, w.data_type)
+        self.assertEqual((784, 100), w.shape)
+        self.assertEqual("fc.w", w.name)
+        self.assertEqual(0, w.lod_level)
+
+        self.assertRaises(ValueError,
+                          lambda: b.create_var(name="fc.w", shape=(24, 100)))
+
+
+if __name__ == '__main__':
+    unittest.main()
```
