@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
5353 """
5454
5555 def __init__ (
56- self , params : _params_t , lr : float = 1e-2 , momentum : float = 0.9 , weight_decay : float = 0 , eps : float = 1e-6 ,
56+ self ,
57+ params : _params_t ,
58+ lr : float = 1e-2 ,
59+ momentum : float = 0.9 ,
60+ weight_decay : float = 0 ,
61+ eps : float = 1e-6 ,
62+ decoupled_decay : bool = False ,
5763 ):
5864 if momentum < 0 or momentum >= 1 :
5965 raise ValueError (f"Momentum { momentum } must be in the range [0,1]" )
@@ -64,7 +70,8 @@ def __init__(
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")
 
-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)
 
     @property
@@ -95,7 +102,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]
 
             ck = 1 - momentum
@@ -120,11 +127,13 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                 s = state["s"]
 
                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)
 
                 if grad.is_sparse:
                     grad = grad.coalesce()
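
For context, here is how the new flag would be exercised once the patch is applied: a minimal sketch, assuming the patched class is importable as madgrad.MADGRAD (the import path and the toy model are assumptions, not part of the commit). With decoupled_decay=True, weight decay shrinks the parameters directly, AdamW-style (p <- p * (1 - lr * weight_decay)), instead of being folded into the gradient before MADGRAD's accumulators are updated; per the diff, this path also sidesteps the RuntimeError that coupled decay raises for sparse gradients.

import torch

from madgrad import MADGRAD  # assumed import path for the patched optimizer

# Toy model and data purely for illustration.
model = torch.nn.Linear(10, 2)
inputs = torch.randn(8, 10)

# decoupled_decay=True multiplies the weights by (1 - lr * weight_decay)
# each step rather than adding weight_decay * p to the gradient.
optimizer = MADGRAD(
    model.parameters(), lr=1e-2, momentum=0.9,
    weight_decay=1e-4, decoupled_decay=True,
)

optimizer.zero_grad()
loss = model(inputs).pow(2).mean()
loss.backward()
optimizer.step()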