
Commit 630af4d

Jiaming-Liu authored and soumith committed
add learning rate schedulers (pytorch#1370)
1 parent 0409b42 commit 630af4d

3 files changed: +478 -1 lines changed


docs/source/optim.rst

Lines changed: 18 additions & 0 deletions
@@ -114,3 +114,21 @@ Algorithms
     :members:
 .. autoclass:: SGD
     :members:
+
+How to adjust Learning Rate
+---------------------------
+
+:mod:`torch.optim.lr_scheduler` provides several methods to adjust the learning
+rate based on the number of epochs. :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`
+allows dynamic learning rate reduction based on some validation measurements.
+
+.. autoclass:: torch.optim.lr_scheduler.LambdaLR
+    :members:
+.. autoclass:: torch.optim.lr_scheduler.StepLR
+    :members:
+.. autoclass:: torch.optim.lr_scheduler.MultiStepLR
+    :members:
+.. autoclass:: torch.optim.lr_scheduler.ExponentialLR
+    :members:
+.. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau
+    :members:
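
The new docs describe the scheduler classes but the diff carries no usage snippet. For orientation, a minimal sketch of the epoch-based pattern that the tests below exercise; the model and the training step are placeholders, and the StepLR configuration is copied from test_step_lr:

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = SGD(model.parameters(), lr=0.05)
    # Multiply every param group's lr by gamma=0.1 every step_size=3 epochs,
    # the same configuration as test_step_lr below.
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    for epoch in range(10):
        scheduler.step(epoch)  # as in the tests, step() receives the epoch index
        # train_one_epoch(model, optimizer)  # placeholder training step

Passing the epoch explicitly mirrors the _test helper in the new test file; the epoch-based schedulers (LambdaLR, StepLR, MultiStepLR, ExponentialLR) all follow this pattern.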

test/test_optim.py

Lines changed: 155 additions & 1 deletion
@@ -4,9 +4,11 @@
 import torch
 import torch.optim as optim
 import torch.legacy.optim as old_optim
+import torch.nn.functional as F
+from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
-
+from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
 from common import TestCase, run_tests
 
@@ -392,5 +394,157 @@ def test_invalid_param_type(self):
         optim.SGD(Variable(torch.randn(5, 5)), lr=3)
 
 
+class SchedulerTestNet(torch.nn.Module):
+    def __init__(self):
+        super(SchedulerTestNet, self).__init__()
+        self.conv1 = torch.nn.Conv2d(1, 1, 1)
+        self.conv2 = torch.nn.Conv2d(1, 1, 1)
+
+    def forward(self, x):
+        return self.conv2(F.relu(self.conv1(x)))
+
+
+class TestLRScheduler(TestCase):
+    def setUp(self):
+        self.net = SchedulerTestNet()
+        self.opt = SGD(
+            [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
+            lr=0.05)
+
+    def test_step_lr(self):
+        # lr = 0.05     if epoch < 3
+        # lr = 0.005    if 3 <= epoch < 6
+        # lr = 0.0005   if 6 <= epoch < 9
+        # lr = 0.00005  if epoch >= 9
+        single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
+        targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        epochs = 10
+        self._test(scheduler, targets, epochs)
+
+    def test_multi_step_lr(self):
+        # lr = 0.05     if epoch < 2
+        # lr = 0.005    if 2 <= epoch < 5
+        # lr = 0.0005   if 5 <= epoch < 9
+        # lr = 0.00005  if epoch >= 9
+        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
+        targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
+        epochs = 10
+        self._test(scheduler, targets, epochs)
+
+    def test_exp_lr(self):
+        single_targets = [0.05 * (0.9 ** x) for x in range(10)]
+        targets = [single_targets, list(map(lambda x: x * 10, single_targets))]
+        scheduler = ExponentialLR(self.opt, gamma=0.9)
+        epochs = 10
+        self._test(scheduler, targets, epochs)
+
+    def test_reduce_lr_on_plateau1(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 20]
+        metrics = [10 - i * 0.0167 for i in range(20)]
+        scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
+                                      threshold=0.01, patience=5, cooldown=5)
+        epochs = 10
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau2(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
+        metrics = [10 - i * 0.0165 for i in range(22)]
+        scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
+                                      mode='min', threshold=0.1)
+        epochs = 22
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau3(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
+        metrics = [-0.8] * 2 + [-0.234] * 20
+        scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
+                                      threshold_mode='abs')
+        epochs = 22
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau4(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 20]
+        metrics = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1 ** 0.25
+        scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3,
+                                      threshold_mode='rel', threshold=0.1)
+        epochs = 20
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau5(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
+        metrics = [1.5 * (1.005 ** i) for i in range(20)]
+        scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel',
+                                      threshold=0.1, patience=5, cooldown=5)
+        epochs = 20
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau6(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 20]
+        metrics = [1.5 * (0.85 ** i) for i in range(20)]
+        scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
+                                      threshold=0.1)
+        epochs = 20
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau7(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
+        metrics = [1] * 7 + [0.6] + [0.5] * 12
+        scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel',
+                                      threshold=0.1, patience=5, cooldown=5)
+        epochs = 20
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_reduce_lr_on_plateau8(self):
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
+        metrics = [1.5 * (1.005 ** i) for i in range(20)]
+        scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3],
+                                      threshold=0.1, patience=5, cooldown=5)
+        epochs = 20
+        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
+
+    def test_lambda_lr(self):
+        self.opt.param_groups[0]['lr'] = 0.05
+        self.opt.param_groups[1]['lr'] = 0.4
+        targets = [[0.05 * (0.9 ** x) for x in range(10)], [0.4 * (0.8 ** x) for x in range(10)]]
+        scheduler = LambdaLR(self.opt,
+                             lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2])
+        epochs = 10
+        self._test(scheduler, targets, epochs)
+
+    def _test(self, scheduler, targets, epochs=10):
+        for epoch in range(epochs):
+            scheduler.step(epoch)
+            for param_group, target in zip(self.opt.param_groups, targets):
+                self.assertAlmostEqual(target[epoch], param_group['lr'],
+                                       msg='LR is wrong in epoch {}: expected {}, got {}'.format(
+                                           epoch, target[epoch], param_group['lr']), delta=1e-5)
+
+    def _test_reduce_lr_on_plateau(self, scheduler, targets, metrics, epochs=10, verbose=False):
+        for epoch in range(epochs):
+            scheduler.step(metrics[epoch])
+            if verbose:
+                print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr']))
+            for param_group, target in zip(self.opt.param_groups, targets):
+                self.assertAlmostEqual(target[epoch], param_group['lr'],
+                                       msg='LR is wrong in epoch {}: expected {}, got {}'.format(
+                                           epoch, target[epoch], param_group['lr']), delta=1e-5)
+
 if __name__ == '__main__':
     run_tests()
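
Unlike the epoch-based schedulers, ReduceLROnPlateau is stepped with a validation metric, as _test_reduce_lr_on_plateau above shows. A minimal sketch of that pattern; validate() is a hypothetical helper (the stand-in metric and the constructor arguments are copied from test_reduce_lr_on_plateau1, and the lr-reduction factor of 0.1 is the default the tests rely on):

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = SGD(model.parameters(), lr=0.5)
    # Reduce each group's lr (by the default factor of 0.1) once the metric
    # stops improving by more than `threshold` for `patience` epochs.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', threshold_mode='abs',
                                  threshold=0.01, patience=5, cooldown=5)

    for epoch in range(20):
        # val_loss = validate(model)   # hypothetical validation helper
        val_loss = 10 - epoch * 0.0167  # stand-in metric from test_reduce_lr_on_plateau1
        scheduler.step(val_loss)  # step on the metric, not the epoch index

Note that min_lr may be a list with one value per param group, as test_reduce_lr_on_plateau8 demonstrates.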
