11import paddle .v2 .framework .core as core
22import collections
3+ import numpy as np
4+ import copy
35
46__all__ = ['Block' , 'Variable' , 'Program' , 'Operator' ]
57
68
class Variable(object):
    """A tensor-like variable inside a Block, backed by a C++ VarDesc.

    If a variable named ``name`` already exists in ``block``, the existing
    desc is reused and any ``shape``/``dtype``/``lod_level`` passed here
    must match the recorded values; a mismatch raises ``ValueError``.
    """

    def __init__(self,
                 block,
                 name=None,
                 shape=None,
                 dtype=None,
                 lod_level=None,
                 **kwargs):
        """Create (or re-open) a variable in ``block``.

        Args:
            block: the Block this variable belongs to.
            name: variable name; auto-generated when None.
            shape: tensor shape; converted to a tuple for comparison.
            dtype: a ``core.DataType`` or anything ``numpy.dtype`` accepts.
            lod_level: level-of-detail (nested sequence) depth.
            **kwargs: extra attributes consumed by subclasses such as
                Parameter; ignored here.

        Raises:
            ValueError: when re-opening an existing variable with a
                conflicting shape, dtype, or lod_level.
        """
        self.block = block

        if name is None:
            name = Variable._unique_var_name_()
        # EAFP: block.desc.var() raises EnforceNotMet for an unknown name,
        # in which case a fresh VarDesc is created instead.
        try:
            self.desc = self.block.desc.var(name)
            is_new_var = False
        except core.EnforceNotMet:
            self.desc = self.block.desc.new_var(name)
            is_new_var = True

        if shape is not None:
            if is_new_var:
                self.desc.set_shape(shape)
            else:
                old_shape = self.shape
                shape = tuple(shape)
                if shape != old_shape:
                    raise ValueError(
                        "Variable {0} has been created before. the previous "
                        "shape is {1}; the new shape is {2}. They are not "
                        "matched.".format(self.name, old_shape, shape))
        if dtype is not None:
            if not isinstance(dtype, core.DataType):
                dtype = Variable._convert_np_dtype_to_dtype_(dtype)
            if is_new_var:
                self.desc.set_data_type(dtype)
            else:
                # Bug fix: the original compared ``dtype != old_shape`` and
                # invoked the ``data_type`` property as if it were a method,
                # so dtype conflicts were never detected correctly.
                old_dtype = self.data_type
                if dtype != old_dtype:
                    raise ValueError("Variable {0} has been created before. "
                                     "The previous data type is {1}; the new "
                                     "data type is {2}. They are not "
                                     "matched.".format(self.name, old_dtype,
                                                       dtype))

        if lod_level is not None:
            if is_new_var:
                self.desc.set_lod_level(lod_level)
            else:
                if lod_level != self.lod_level:
                    raise ValueError("Variable {0} has been created before. "
                                     "The previous lod_level is {1}; the new "
                                     "lod_level is {2}. They are not "
                                     "matched".format(self.name, self.lod_level,
                                                      lod_level))
        self.block.vars[name] = self
        # The operator that produces this variable, filled in later.
        self.op = None

    @property
    def name(self):
        return self.desc.name()

    @property
    def shape(self):
        # Tuple, to mirror the numpy ndarray.shape API.
        return tuple(self.desc.shape())

    @property
    def data_type(self):
        return self.desc.data_type()

    @property
    def lod_level(self):
        return self.desc.lod_level()

    @staticmethod
    def _unique_var_name_():
        uid = core.unique_integer()  # unique during whole process.
        return "_generated_var_%d" % uid

    @staticmethod
    def _convert_np_dtype_to_dtype_(np_dtype):
        """Map a numpy dtype (or anything ``np.dtype`` accepts) to core.DataType.

        Raises:
            ValueError: for dtypes with no core.DataType counterpart.
        """
        dtype = np.dtype(np_dtype)
        if dtype == np.float32:
            return core.DataType.FP32
        elif dtype == np.float64:
            return core.DataType.FP64
        elif dtype == np.float16:
            return core.DataType.FP16
        elif dtype == np.int32:
            return core.DataType.INT32
        elif dtype == np.int16:
            return core.DataType.INT16
        elif dtype == np.int64:
            return core.DataType.INT64
        elif dtype == np.bool_:
            # Bug fix: ``np.bool`` was deprecated and removed in NumPy 1.24;
            # ``np.bool_`` is the supported scalar type.
            return core.DataType.BOOL
        else:
            raise ValueError("Not supported numpy dtype " + str(dtype))
107+
37108
38109class Operator (object ):
39110 def __init__ (self ,
40111 block ,
41- proto ,
112+ desc ,
42113 type = None ,
43114 inputs = None ,
44115 outputs = None ,
45116 attrs = None ):
46117 self .block = block
47- self .proto = proto
118+ self .desc = desc
48119 if type is not None :
49120 # TODO.
50121 pass
@@ -58,36 +129,40 @@ def __init__(self,
58129 # TODO
59130 pass
60131
61- # TODO: Getters
132+ # TODO: Getters
62133
63134
class Block(object):
    """A block of operators and variables, mirroring a C++ BlockDesc."""

    def __init__(self, program, idx):
        self.desc = program.desc.block(idx)
        self.vars = {}  # variable name -> Variable
        self.ops = collections.deque()  # operators, in execution order
        self.program = program

    @property
    def parent_idx(self):
        """Index of the parent block inside the program."""
        return self.desc.parent

    @property
    def idx(self):
        """This block's own index inside the program."""
        return self.desc.id

    def create_var(self, *args, **kwargs):
        """Create a Variable that lives in this block."""
        return Variable(self, *args, **kwargs)

    def create_parameter(self, *args, **kwargs):
        """Create a Parameter; parameters always live in the global block."""
        return Parameter(self.program.global_block(), *args, **kwargs)

    def append_op(self, *args, **kwargs):
        """Append an operator at the end of this block and return it."""
        op = Operator(self, self.desc.append_op(), *args, **kwargs)
        self.ops.append(op)
        return op

    def prepend_op(self, *args, **kwargs):
        """Insert an operator at the front of this block and return it."""
        op = Operator(self, self.desc.prepend_op(), *args, **kwargs)
        self.ops.appendleft(op)
        return op
93168
@@ -104,7 +179,7 @@ def instance(cls):
104179 def __init__ (self ):
105180 assert not hasattr (self .__class__ ,
106181 '_instance' ), 'Do not call constructor directly!'
107- self .proto = core .ProgramDesc .instance ()
182+ self .desc = core .ProgramDesc .instance ()
108183 self .blocks = [Block (self , 0 )]
109184 self .current_block_idx = 0
110185
@@ -116,7 +191,7 @@ def current_block(self):
116191
117192 def create_block (self ):
118193 new_block_idx = len (self .blocks )
119- self .proto .append_block (self .current_block ().proto )
194+ self .desc .append_block (self .current_block ().desc )
120195 self .current_block_idx = new_block_idx
121196 self .blocks .append (Block (self , self .current_block_idx ))
122197 return self .current_block ()
@@ -125,5 +200,41 @@ def rollback(self):
125200 self .current_block_idx = self .current_block ().parent_idx
126201
127202
class Parameter(Variable):
    """A trainable Variable that lives in the global block and prepends
    its own initialization operator to that block."""

    def __init__(self, block, shape, dtype, **kwargs):
        """Validate shape/dtype, create the underlying Variable, then
        record training attributes and emit the initializer op.

        Raises:
            ValueError: when shape/dtype are missing, shape is empty,
                or shape contains a negative (batch-size-like) dimension.
        """
        if shape is None or dtype is None:
            raise ValueError("Parameter must set shape and dtype")
        if len(shape) == 0:
            raise ValueError("Parameter shape cannot be empty")
        if any(dim < 0 for dim in shape):
            raise ValueError("Parameter shape should not be related with "
                             "batch-size")

        Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
        self.trainable = kwargs.get('trainable', True)
        self.init_attr = kwargs.get('initialize_attr', {
            'type': 'uniform_random',
            'min': -1.0,
            'max': 1.0
        })
        self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
        self._append_initialize_ops_()

    def _append_initialize_ops_(self):
        # Work on a deep copy so self.init_attr itself stays untouched.
        init_attrs = copy.deepcopy(self.init_attr)
        initializer_type = init_attrs.pop('type', None)
        owner = self.block
        assert isinstance(owner, Block)
        init_attrs['dims'] = self.shape
        init_attrs['data_type'] = int(self.data_type)
        # Initialization must run before any op that reads the parameter,
        # hence prepend rather than append.
        self.op = owner.prepend_op(
            type=initializer_type,
            inputs=None,
            outputs={'Out': [self]},
            attrs=init_attrs)
237+
238+
# program is a global instance: the default Program singleton that the
# rest of the framework appends blocks, variables, and operators to.
g_program = Program.instance()