@@ -84,13 +84,11 @@ class Lamb(Optimizer):
8484 """
8585
8686 def __init__ (
87- self , params , lr = 1e-3 , bias_correction = True , betas = (0.9 , 0.999 ), eps = 1e-6 , weight_decay = 0.01 ,
88- grad_averaging = True , max_grad_norm = 1.0 , decoupled_decay = False , use_nvlamb = False ):
87+ self , params , lr = 1e-3 , bias_correction = True , betas = (0.9 , 0.999 ), eps = 1e-6 ,
88+ weight_decay = 0.01 , grad_averaging = True , max_grad_norm = 1.0 , use_nvlamb = False ):
8989 defaults = dict (
90- lr = lr , bias_correction = bias_correction ,
91- betas = betas , eps = eps , weight_decay = weight_decay ,
92- grad_averaging = grad_averaging , max_grad_norm = max_grad_norm ,
93- decoupled_decay = decoupled_decay , use_nvlamb = use_nvlamb )
90+ lr = lr , bias_correction = bias_correction , betas = betas , eps = eps , weight_decay = weight_decay ,
91+ grad_averaging = grad_averaging , max_grad_norm = max_grad_norm , use_nvlamb = use_nvlamb )
9492 super ().__init__ (params , defaults )
9593
9694 def step (self , closure = None ):
@@ -136,8 +134,6 @@ def step(self, closure=None):
             else:
                 group['step'] = 1
 
-            step_size = group['lr']
-
             if bias_correction:
                 bias_correction1 = 1 - beta1 ** group['step']
                 bias_correction2 = 1 - beta2 ** group['step']
@@ -157,11 +153,6 @@ def step(self, closure=None):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)
 
-                decoupled_decay = group['decoupled_decay']
-                weight_decay = group['weight_decay']
-                if decoupled_decay and weight_decay != 0:
-                    p.data.mul_(1. - group['lr'] * weight_decay)
-
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 
                 # Decay the first and second moment running average coefficient
@@ -171,7 +162,8 @@ def step(self, closure=None):
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 update = (exp_avg / bias_correction1).div_(denom)
 
-                if not decoupled_decay and weight_decay != 0:
+                weight_decay = group['weight_decay']
+                if weight_decay != 0:
                     update.add_(p.data, alpha=weight_decay)
 
                 trust_ratio = one_tensor
@@ -186,6 +178,6 @@ def step(self, closure=None):
                     one_tensor,
                 )
                 update.mul_(trust_ratio)
-                p.data.add_(update, alpha=-step_size)
+                p.data.add_(update, alpha=-group['lr'])
 
         return loss
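
Net effect of the diff: the AdamW-style decoupled_decay branch is removed, so weight_decay is now always folded into the Adam-style update before the trust-ratio scaling (the decay term is therefore normalized by the trust ratio too, as in the original LAMB formulation), and the redundant step_size local is dropped in favor of reading group['lr'] directly. Below is a minimal usage sketch of the simplified constructor; the import path and the toy model are hypothetical, assuming the patched class above is importable as Lamb:

import torch
import torch.nn as nn

from lamb import Lamb  # hypothetical import path for the patched optimizer

model = nn.Linear(10, 2)  # toy model; any nn.Module works the same way

# Keyword arguments match the updated __init__ signature; there is
# no decoupled_decay flag any more.
optimizer = Lamb(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01,  # always applied inside the trust-ratio-scaled update now
    max_grad_norm=1.0,
    use_nvlamb=False,
)

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()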