@@ -33,14 +33,12 @@ def setUp(self):
3333
3434 self .attrs  =  {'epsilon' : epsilon , 'beta1' : beta1 , 'beta2' : beta2 }
3535
36-  param_out , moment1_out , moment2_out ,  beta1_pow_out ,  \
37-  beta2_pow_out  =  adam_step (self .inputs , self .attrs )
36+  param_out , moment1_out , \
37+  moment2_out  =  adam_step (self .inputs , self .attrs )
3838
3939 self .outputs  =  {
4040 'Moment1Out' : moment1_out ,
4141 'Moment2Out' : moment2_out ,
42-  'Beta1PowOut' : beta1_pow_out ,
43-  'Beta2PowOut' : beta2_pow_out ,
4442 'ParamOut' : param_out 
4543 }
4644
@@ -78,14 +76,12 @@ def setUp(self):
7876
7977 attributes  =  {'epsilon' : epsilon , 'beta1' : beta1 , 'beta2' : beta2 }
8078
81-  param_out , moment1_out , moment2_out ,  beta1_pow_out ,  \
82-  beta2_pow_out  =  adam_step (self .inputs , attributes )
79+  param_out , moment1_out , \
80+  moment2_out  =  adam_step (self .inputs , attributes )
8381
8482 self .outputs  =  {
8583 'Moment1Out' : moment1_out ,
8684 'Moment2Out' : moment2_out ,
87-  'Beta1PowOut' : beta1_pow_out ,
88-  'Beta2PowOut' : beta2_pow_out ,
8985 'ParamOut' : param_out 
9086 }
9187
@@ -127,14 +123,12 @@ def setUp(self):
127123
128124 def  test_check_output (self ):
129125 for  _  in  range (self .num_steps ):
130-  param_out , moment1_out , moment2_out ,  beta1_pow_out ,  \
131-  beta2_pow_out  =  adam_step (self .inputs , self .attrs )
126+  param_out , moment1_out , \
127+  moment2_out  =  adam_step (self .inputs , self .attrs )
132128
133129 self .outputs  =  {
134130 'Moment1Out' : moment1_out ,
135131 'Moment2Out' : moment2_out ,
136-  'Beta1PowOut' : beta1_pow_out ,
137-  'Beta2PowOut' : beta2_pow_out ,
138132 'ParamOut' : param_out 
139133 }
140134
@@ -145,8 +139,10 @@ def test_check_output(self):
145139 self .inputs ['Param' ] =  param_out 
146140 self .inputs ['Moment1' ] =  moment1_out 
147141 self .inputs ['Moment2' ] =  moment2_out 
148-  self .inputs ['Beta1Pow' ] =  beta1_pow_out 
149-  self .inputs ['Beta2Pow' ] =  beta2_pow_out 
142+ 
143+  # Update powers of Beta1 and Beta2 for next time step 
144+  self .inputs ['Beta1Pow' ] *=  self .attrs ['beta1' ]
144+  self .inputs ['Beta2Pow' ] *=  self .attrs ['beta2' ]
150146
151147 # Randomize gradient for next step 
152148 self .inputs ['Grad' ] =  np .random .uniform (
@@ -175,11 +171,9 @@ def adam_step(inputs, attributes):
175171
176172 moment1_out  =  beta1  *  moment1  +  (1  -  beta1 ) *  grad 
177173 moment2_out  =  beta2  *  moment2  +  (1  -  beta2 ) *  np .square (grad )
178-  beta1_pow_out  =  beta1_pow  *  beta1 
179-  beta2_pow_out  =  beta2_pow  *  beta2 
180-  lr_t  =  lr  *  np .sqrt (1  -  beta2_pow_out ) /  (1  -  beta1_pow_out )
174+  lr_t  =  lr  *  np .sqrt (1  -  beta2_pow ) /  (1  -  beta1_pow )
181175 param_out  =  param  -  lr_t  *  (moment1_out  /  (np .sqrt (moment2_out ) +  epsilon ))
182-  return  param_out , moment1_out , moment2_out ,  beta1_pow_out ,  beta2_pow_out 
176+  return  param_out , moment1_out , moment2_out 
183177
184178
185179if  __name__  ==  "__main__" :
0 commit comments