
Commit a6af48b

add madgradw optimizer

1 parent 55fb5ee

3 files changed: 20 additions, 9 deletions

tests/test_optim.py (1 addition, 1 deletion)

@@ -490,7 +490,7 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


-@pytest.mark.parametrize('optimizer', ['madgrad'])
+@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
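For context, the parametrization simply runs the same test body once per optimizer name. A minimal standalone sketch of what each case builds, mirroring the helper call above (the tensor shapes and the single dummy step are illustrative assumptions, not part of the test suite; it assumes create_optimizer_v2 is importable from timm.optim as in the test module):

import torch
from timm.optim import create_optimizer_v2

for opt_name in ('madgrad', 'madgradw'):
    # mirror the test: a bare weight/bias pair passed as a parameter list
    weight = torch.randn(10, 5, requires_grad=True)
    bias = torch.randn(10, requires_grad=True)
    optimizer = create_optimizer_v2([weight, bias], opt_name, lr=1e-3)
    loss = weight.sum() + bias.sum()  # dummy objective
    loss.backward()
    optimizer.step()  # one update to confirm the optimizer is wired up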

timm/optim/madgrad.py (17 additions, 8 deletions)

@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
     """

     def __init__(
-            self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
+            self,
+            params: _params_t,
+            lr: float = 1e-2,
+            momentum: float = 0.9,
+            weight_decay: float = 0,
+            eps: float = 1e-6,
+            decoupled_decay: bool = False,
     ):
         if momentum < 0 or momentum >= 1:
             raise ValueError(f"Momentum {momentum} must be in the range [0,1]")

@@ -64,7 +70,8 @@ def __init__(
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")

-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)

     @property

@@ -95,7 +102,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]

             ck = 1 - momentum

@@ -120,11 +127,13 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                 s = state["s"]

                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)

                 if grad.is_sparse:
                     grad = grad.coalesce()
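The new flag switches how weight decay enters the update: by default it is folded into the gradient as an L2 penalty (grad.add_(p.data, alpha=weight_decay)), while decoupled_decay=True shrinks the parameters directly (p.data.mul_(1.0 - lr * weight_decay)) and keeps the decay out of the adaptive gradient statistics, in the AdamW style. A minimal sketch of constructing the optimizer directly with the new argument (the toy model and hyperparameter values are illustrative assumptions):

import torch.nn as nn
from timm.optim.madgrad import MADGRAD

model = nn.Linear(10, 2)  # toy model, illustrative only

# default: coupled (L2) weight decay, added to the gradient before the update
opt_l2 = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4)

# new in this commit: decoupled decay applied directly to the parameters
opt_decoupled = MADGRAD(
    model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4, decoupled_decay=True)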

timm/optim/optim_factory.py (2 additions, 0 deletions)

@@ -165,6 +165,8 @@ def create_optimizer_v2(
         optimizer = Lamb(parameters, **opt_args)
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'madgradw':
+        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
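Through the factory, the decoupled variant is selected by the optimizer name alone, with 'madgradw' mapping to MADGRAD(..., decoupled_decay=True). A minimal usage sketch (the model choice and hyperparameter values are illustrative assumptions):

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18', pretrained=False)  # illustrative model
# 'madgradw' routes to MADGRAD with decoupled_decay=True, per the branch above
optimizer = create_optimizer_v2(model, 'madgradw', lr=1e-2, weight_decay=1e-2)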
