Commit 55fb5ee

Remove experiment from lamb impl
1 parent 8a9eca5 commit 55fb5ee

4 files changed: +10 -20 lines changed


tests/test_optim.py

Lines changed: 1 addition & 1 deletion

@@ -463,7 +463,7 @@ def test_adafactor(optimizer):
     _test_model(optimizer, dict(lr=5e-2))


-@pytest.mark.parametrize('optimizer', ['lamb', 'lambw'])
+@pytest.mark.parametrize('optimizer', ['lamb'])
 def test_lamb(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
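The parametrize marker above is what fans one test function out into one collected case per optimizer name; with the list trimmed to 'lamb', pytest now collects only test_lamb[lamb]. A minimal standalone illustration of the same pattern (not part of the repo's test suite):

import pytest

# Collected once, as test_optimizer_name[lamb]; each extra name in the list
# would add one more collected case.
@pytest.mark.parametrize('optimizer', ['lamb'])
def test_optimizer_name(optimizer):
    assert optimizer == 'lamb'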

timm/optim/adabelief.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@ class AdaBelief(Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False)
-        decoupled_decay (boolean, optional): ( default: True) If set as True, then
+        decoupled_decay (boolean, optional): (default: True) If set as True, then
             the optimizer uses decoupled weight decay as in AdamW
         fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
             is set as True.
@@ -194,7 +194,7 @@ def step(self, closure=None):
                         denom = exp_avg_var.sqrt().add_(group['eps'])
                         p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
-                        p.data.add_( exp_avg, alpha=-step_size * group['lr'])
+                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])

                 if half_precision:
                     p.data = p.data.half()
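The decoupled_decay flag documented above is exactly the distinction this commit removes from Lamb: decoupled (AdamW-style) decay shrinks the weights directly, while coupled (L2-style) decay folds the decay term into the gradient-based update. A toy sketch of the two forms on a bare SGD-like step, using hypothetical tensors (illustration only, not timm code):

import torch

param = torch.randn(4)
grad = torch.randn(4)
lr, weight_decay = 1e-3, 1e-2

# Coupled (L2) decay: the decay term joins the update, so it also passes
# through whatever scaling the optimizer applies to that update.
update = grad + weight_decay * param
param_l2 = param - lr * update

# Decoupled decay (AdamW-style): the weights are shrunk directly,
# independently of the gradient-based update.
param_decoupled = param * (1. - lr * weight_decay) - lr * grad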

timm/optim/lamb.py

Lines changed: 7 additions & 15 deletions

@@ -84,13 +84,11 @@ class Lamb(Optimizer):
     """

     def __init__(
-            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
-            grad_averaging=True, max_grad_norm=1.0, decoupled_decay=False, use_nvlamb=False):
+            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
+            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False):
         defaults = dict(
-            lr=lr, bias_correction=bias_correction,
-            betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
-            decoupled_decay=decoupled_decay, use_nvlamb=use_nvlamb)
+            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
+            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
         super().__init__(params, defaults)

     def step(self, closure=None):
@@ -136,8 +134,6 @@ def step(self, closure=None):
             else:
                 group['step'] = 1

-            step_size = group['lr']
-
             if bias_correction:
                 bias_correction1 = 1 - beta1 ** group['step']
                 bias_correction2 = 1 - beta2 ** group['step']
@@ -157,11 +153,6 @@ def step(self, closure=None):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                decoupled_decay = group['decoupled_decay']
-                weight_decay = group['weight_decay']
-                if decoupled_decay and weight_decay != 0:
-                    p.data.mul_(1. - group['lr'] * weight_decay)
-
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                 # Decay the first and second moment running average coefficient
@@ -171,7 +162,8 @@ def step(self, closure=None):
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 update = (exp_avg / bias_correction1).div_(denom)

-                if not decoupled_decay and weight_decay != 0:
+                weight_decay = group['weight_decay']
+                if weight_decay != 0:
                     update.add_(p.data, alpha=weight_decay)

                 trust_ratio = one_tensor
@@ -186,6 +178,6 @@ def step(self, closure=None):
                     one_tensor,
                 )
                 update.mul_(trust_ratio)
-                p.data.add_(update, alpha=-step_size)
+                p.data.add_(update, alpha=-group['lr'])

         return loss
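Taken together, the lamb.py changes leave a single update path: weight decay is always added into the Adam-style update before the trust ratio is computed, and the parameter step is scaled directly by group['lr'] rather than a separate step_size. Below is a condensed single-tensor sketch of that remaining path, with simplified names; it omits the real implementation's global gradient-norm clipping, grad_averaging and use_nvlamb handling:

import math
import torch

def lamb_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
              betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
    # Adam-style first and second moment updates.
    beta1, beta2 = betas
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    update = (exp_avg / bias_correction1).div_(denom)

    # Weight decay now always lands here, inside the update, so the trust
    # ratio below sees it; the removed decoupled path shrank p directly instead.
    if weight_decay != 0:
        update.add_(p, alpha=weight_decay)

    # Layer-wise trust ratio, then a single scaling by lr (no separate step_size).
    w_norm = p.norm(2.0)
    g_norm = update.norm(2.0)
    trust_ratio = w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
    update.mul_(trust_ratio)
    p.add_(update, alpha=-lr)


# Example: one step on a random tensor with freshly initialized moments.
p = torch.randn(8)
lamb_step(p, torch.randn(8), torch.zeros(8), torch.zeros(8), step=1)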

timm/optim/optim_factory.py

Lines changed: 0 additions & 2 deletions

@@ -163,8 +163,6 @@ def create_optimizer_v2(
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'lamb':
         optimizer = Lamb(parameters, **opt_args)
-    elif opt_lower == 'lambw':
-        optimizer = Lamb(parameters, decoupled_decay=True, **opt_args)  # FIXME experimental
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
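With the experimental branch gone, 'lamb' is the only name the factory maps to the Lamb class. A usage sketch following the call shape seen in tests/test_optim.py above; the keyword set and import path are assumed from that test and timm's optim package, not restated from the factory's full signature:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)
# The second positional argument selects the optimizer by name, as in the test;
# 'lambw' no longer resolves to anything.
optimizer = create_optimizer_v2(list(model.parameters()), 'lamb', lr=1e-3, weight_decay=0.01)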
