A LearningRateSchedule that uses a cosine decay schedule with restarts.
Inherits From: LearningRateSchedule
```python
tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate,
    first_decay_steps,
    t_mul=2.0,
    m_mul=1.0,
    alpha=0.0,
    name='SGDRDecay'
)
```
See Loshchilov & Hutter, ICLR 2017, SGDR: Stochastic Gradient Descent with Warm Restarts.
When training a model, it is often useful to lower the learning rate as training progresses. This schedule applies a cosine decay function with restarts to an optimizer step, given an initial learning rate. It requires a step value to compute the decayed learning rate; you can simply pass a backend variable that you increment at each training step.
The schedule is a 1-arg callable that produces a decayed learning rate when passed the current optimizer step. This can be useful for changing the learning rate value across different invocations of optimizer functions.
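To make the step-driven behaviour concrete without Keras, here is a minimal plain-Python sketch of the schedule as a function of an integer step. It assumes periods of first_decay_steps * t_mul**i steps, a cosine term scaled by m_mul**i after i restarts, and alpha acting as a floor that m_mul does not scale. The function name cosine_decay_restarts, the loop-based search for the current period, and the t_mul >= 1 restriction are choices made for this illustration only; this is not the Keras implementation.

```python
import math


def cosine_decay_restarts(step, initial_learning_rate, first_decay_steps,
                          t_mul=2.0, m_mul=1.0, alpha=0.0):
    """Hypothetical plain-Python sketch of cosine decay with warm restarts."""
    if first_decay_steps <= 0 or t_mul < 1.0:
        # Assumption of this sketch: with t_mul < 1 the periods shrink
        # geometrically and the search loop below would never end for large steps.
        raise ValueError("this sketch assumes first_decay_steps > 0 and t_mul >= 1")
    start = 0.0                        # step at which the current period began
    period = float(first_decay_steps)  # length of the current period
    i_restart = 0                      # number of warm restarts performed so far
    # Walk forward period by period until `step` falls inside the current one.
    while step >= start + period:
        start += period
        period *= t_mul
        i_restart += 1
    fraction = (step - start) / period  # progress through the current period, in [0, 1)
    # Cosine term falls from m_mul**i_restart towards 0 over the period.
    cosine = 0.5 * (m_mul ** i_restart) * (1.0 + math.cos(math.pi * fraction))
    # Assumption: alpha is a floor that is not scaled by m_mul.
    return initial_learning_rate * ((1.0 - alpha) * cosine + alpha)


# Simulated training loop: a plain integer counter stands in for the
# backend step variable that would be incremented once per training step.
step = 0
for _ in range(3001):
    lr = cosine_decay_restarts(step, initial_learning_rate=0.1, first_decay_steps=1000)
    if step in (0, 500, 999, 1000, 2000, 3000):
        print(step, round(lr, 5))
    step += 1
```

With the defaults, the printed rate is 0.1 at steps 0, 1000 and 3000 (the start of each period), 0.05 halfway through a period (steps 500 and 2000), and close to zero just before a restart (step 999).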
The learning rate multiplier first decays from 1 to alpha over first_decay_steps steps. Then a warm restart is performed: the multiplier jumps back up and decays again. Each new period runs for t_mul times as many steps as the previous one, and its starting learning rate is scaled by a further factor of m_mul relative to the previous period.
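Written out as a formula, the description above can be sketched as follows, assuming m_mul scales only the cosine term and not the alpha floor. The symbols are introduced for this illustration only: $\eta_0$ is initial_learning_rate, $T_0$ is first_decay_steps, $t$ is t_mul, $m$ is m_mul, $s$ is the optimizer step, and $i$ is the number of restarts so far.

$$
\begin{aligned}
T_i &= T_0\, t^{i}, \qquad
S_i = \sum_{k=0}^{i-1} T_k =
\begin{cases}
T_0\,\dfrac{t^{i}-1}{t-1}, & t \neq 1,\\[4pt]
i\,T_0, & t = 1,
\end{cases}\\
f &= \frac{s - S_i}{T_i}, \qquad S_i \le s < S_i + T_i,\\
\eta(s) &= \eta_0 \left[ (1-\alpha)\,\tfrac{1}{2}\, m^{i} \bigl(1 + \cos(\pi f)\bigr) + \alpha \right].
\end{aligned}
$$

At the start of period $i$ ($f = 0$) the rate is $\eta_0\,[(1-\alpha)\,m^{i} + \alpha]$, which reduces to $m^{i}\eta_0$ when $\alpha = 0$; as $f \to 1$ the rate approaches $\alpha\,\eta_0$.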
Example:
```python
import keras

initial_learning_rate = 0.1  # illustrative value
first_decay_steps = 1000
lr_decayed_fn = keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate, first_decay_steps)
optimizer = keras.optimizers.SGD(learning_rate=lr_decayed_fn)
```
You can pass this schedule directly into a keras.optimizers.Optimizer as the learning rate. The learning rate schedule is also serializable and deserializable using keras.optimizers.schedules.serialize and keras.optimizers.schedules.deserialize.
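With first_decay_steps = 1000 and the default t_mul = 2.0, each period is twice as long as the one before it. The hypothetical helper restart_steps below lists the step at which each period begins by summing period lengths; it is a sketch for reasoning about the schedule, not part of the Keras API.

```python
def restart_steps(first_decay_steps, t_mul=2.0, num_periods=5):
    """Hypothetical helper: step at which each of the first `num_periods` periods begins."""
    starts = []
    start, period = 0.0, float(first_decay_steps)
    for _ in range(num_periods):
        starts.append(start)
        start += period   # the next period begins where this one ends
        period *= t_mul   # and lasts t_mul times as long
    return starts


print(restart_steps(1000))              # [0.0, 1000.0, 3000.0, 7000.0, 15000.0]
print(restart_steps(1000, t_mul=1.0))   # [0.0, 1000.0, 2000.0, 3000.0, 4000.0]
```

So with the defaults, warm restarts happen at steps 1000, 3000, 7000 and 15000, whereas t_mul = 1.0 restarts every first_decay_steps steps.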
| Returns |
|---|
| A 1-arg callable learning rate schedule that takes the current optimizer step and outputs the decayed learning rate, a scalar tensor of the same type as initial_learning_rate. |
Methods
from_config
```python
@classmethod
from_config(config)
```
Instantiates a LearningRateSchedule from its config.
| Args | |
|---|---|
| config | Output of get_config(). |
| Returns |
|---|
| A LearningRateSchedule instance. |
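Because the config is a plain dictionary of constructor arguments, from_config can rebuild an equivalent schedule from it, which is what makes serialization possible. The sketch below mimics that round trip with a hypothetical class CosineRestartsConfigSketch, assuming the config holds the constructor arguments from the signature at the top of this page; it is not the Keras class, and it models only the config methods, not the decay computation.

```python
import json


class CosineRestartsConfigSketch:
    """Hypothetical stand-in that only models the get_config/from_config pattern."""

    def __init__(self, initial_learning_rate, first_decay_steps,
                 t_mul=2.0, m_mul=1.0, alpha=0.0, name="SGDRDecay"):
        self.initial_learning_rate = initial_learning_rate
        self.first_decay_steps = first_decay_steps
        self.t_mul = t_mul
        self.m_mul = m_mul
        self.alpha = alpha
        self.name = name

    def get_config(self):
        # A plain dict of constructor arguments, so it serializes to JSON directly.
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "first_decay_steps": self.first_decay_steps,
            "t_mul": self.t_mul,
            "m_mul": self.m_mul,
            "alpha": self.alpha,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent instance from the output of get_config().
        return cls(**config)


original = CosineRestartsConfigSketch(0.1, 1000, m_mul=0.9)
payload = json.dumps(original.get_config())  # serialize the config to a string
restored = CosineRestartsConfigSketch.from_config(json.loads(payload))
print(restored.get_config() == original.get_config())  # True
```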
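get_config
```python
get_config()
```
Returns the configuration of the schedule, in the form that from_config accepts.
__call__
```python
__call__(step)
```
Computes the decayed learning rate for the given optimizer step.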