|
4 | 4 | import torch |
5 | 5 | import torch.optim as optim |
6 | 6 | import torch.legacy.optim as old_optim |
| 7 | +import torch.nn.functional as F |
| 8 | +from torch.optim import SGD |
7 | 9 | from torch.autograd import Variable |
8 | 10 | from torch import sparse |
9 | | - |
| 11 | +from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau |
10 | 12 | from common import TestCase, run_tests |
11 | 13 |
|
12 | 14 |
|
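The new imports pull in the scheduler classes under test. For orientation, the pattern the tests below rely on is roughly the following (a minimal sketch, not part of the diff; the model and the loop body are placeholders): a scheduler wraps an already-constructed optimizer and rewrites the `lr` entry of each param group when `step()` is called.

    # Minimal usage sketch (illustrative; assumes the pre-0.4 API used in this file,
    # where step() takes the epoch index explicitly).
    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR

    model = torch.nn.Linear(4, 2)                          # placeholder model
    optimizer = SGD(model.parameters(), lr=0.05)
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # multiply lr by 0.1 every 3 epochs

    for epoch in range(10):
        scheduler.step(epoch)  # sets optimizer.param_groups[i]['lr'] for this epoch
        # ... forward / backward / optimizer.step() would go here ...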
@@ -392,5 +394,157 @@ def test_invalid_param_type(self): |
392 | 394 | optim.SGD(Variable(torch.randn(5, 5)), lr=3) |
393 | 395 |
|
394 | 396 |
|
| 397 | +class SchedulerTestNet(torch.nn.Module): |
| 398 | + def __init__(self): |
| 399 | + super(SchedulerTestNet, self).__init__() |
| 400 | + self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| 401 | + self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| 402 | + |
| 403 | + def forward(self, x): |
| 404 | + return self.conv2(F.relu(self.conv1(x))) |
| 405 | + |
| 406 | + |
| 407 | +class TestLRScheduler(TestCase): |
| 408 | + def setUp(self): |
| 409 | + self.net = SchedulerTestNet() |
| 410 | + self.opt = SGD( |
| 411 | + [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], |
| 412 | + lr=0.05) |
| 413 | + |
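The fixture builds one optimizer over two param groups: conv1 inherits the default lr=0.05, while conv2 overrides it with lr=0.5. That is why every `targets` list below comes in pairs, with the second schedule exactly 10x the first. A quick standalone check of that setup (a sketch reusing the class and import defined in this file):

    net = SchedulerTestNet()
    opt = SGD([{'params': net.conv1.parameters()},
               {'params': net.conv2.parameters(), 'lr': 0.5}], lr=0.05)
    # group 0 falls back to the optimizer-wide default, group 1 keeps its override
    assert [group['lr'] for group in opt.param_groups] == [0.05, 0.5]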
| 414 | + def test_step_lr(self): |
| 415 | + # lr = 0.05 if epoch < 3 |
| 416 | +        # lr = 0.005 if 3 <= epoch < 6 |
| 417 | +        # lr = 0.0005 if 6 <= epoch < 9, 0.00005 if epoch >= 9 |
| 418 | + single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 |
| 419 | + targets = [single_targets, list(map(lambda x: x * 10, single_targets))] |
| 420 | + scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| 421 | + epochs = 10 |
| 422 | + self._test(scheduler, targets, epochs) |
| 423 | + |
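The expected values encode StepLR's closed form, lr = base_lr * gamma ** (epoch // step_size). A short sketch (not part of the test file) reproducing `single_targets`:

    base_lr, gamma, step_size = 0.05, 0.1, 3
    expected = [base_lr * gamma ** (epoch // step_size) for epoch in range(12)]
    # -> [0.05]*3 + [0.005]*3 + [0.0005]*3 + [5e-05]*3 (up to float rounding)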
| 424 | + def test_multi_step_lr(self): |
| 425 | + # lr = 0.05 if epoch < 2 |
| 426 | + # lr = 0.005 if 2 <= epoch < 5 |
| 427 | +        # lr = 0.0005 if 5 <= epoch < 9 |
| 428 | + # lr = 0.00005 if epoch >= 9 |
| 429 | + single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 |
| 430 | + targets = [single_targets, list(map(lambda x: x * 10, single_targets))] |
| 431 | + scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| 432 | + epochs = 10 |
| 433 | + self._test(scheduler, targets, epochs) |
| 434 | + |
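MultiStepLR multiplies the lr by gamma each time a milestone epoch is reached, i.e. lr = base_lr * gamma ** (number of milestones <= epoch). The same table can be rebuilt with `bisect` (sketch):

    from bisect import bisect_right
    base_lr, gamma, milestones = 0.05, 0.1, [2, 5, 9]
    expected = [base_lr * gamma ** bisect_right(milestones, epoch) for epoch in range(12)]
    # -> [0.05]*2 + [0.005]*3 + [0.0005]*4 + [5e-05]*3, matching single_targets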
| 435 | + def test_exp_lr(self): |
| 436 | + single_targets = [0.05 * (0.9 ** x) for x in range(10)] |
| 437 | + targets = [single_targets, list(map(lambda x: x * 10, single_targets))] |
| 438 | + scheduler = ExponentialLR(self.opt, gamma=0.9) |
| 439 | + epochs = 10 |
| 440 | + self._test(scheduler, targets, epochs) |
| 441 | + |
| 442 | + def test_reduce_lr_on_plateau1(self): |
| 443 | + for param_group in self.opt.param_groups: |
| 444 | + param_group['lr'] = 0.5 |
| 445 | + targets = [[0.5] * 20] |
| 446 | + metrics = [10 - i * 0.0167 for i in range(20)] |
| 447 | + scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', |
| 448 | + threshold=0.01, patience=5, cooldown=5) |
| 449 | + epochs = 10 |
| 450 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 451 | + |
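With mode='min' and threshold_mode='abs', an epoch only counts as stagnation when the metric fails to drop below best - threshold. The metric here falls by 0.0167 per epoch, more than threshold=0.01, so every epoch is an improvement and the lr stays at 0.5. A numeric check of that reasoning (a sketch, not part of the test):

    threshold = 0.01
    metrics = [10 - i * 0.0167 for i in range(20)]
    best = metrics[0]
    for m in metrics[1:]:
        assert m < best - threshold  # each epoch improves, so patience never runs out
        best = m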
| 452 | + def test_reduce_lr_on_plateau2(self): |
| 453 | + for param_group in self.opt.param_groups: |
| 454 | + param_group['lr'] = 0.5 |
| 455 | + targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] |
| 456 | + metrics = [10 - i * 0.0165 for i in range(22)] |
| 457 | + scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', |
| 458 | + mode='min', threshold=0.1) |
| 459 | + epochs = 22 |
| 460 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 461 | + |
| 462 | + def test_reduce_lr_on_plateau3(self): |
| 463 | + for param_group in self.opt.param_groups: |
| 464 | + param_group['lr'] = 0.5 |
| 465 | + targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] |
| 466 | + metrics = [-0.8] * 2 + [-0.234] * 20 |
| 467 | + scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, |
| 468 | + threshold_mode='abs') |
| 469 | + epochs = 22 |
| 470 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 471 | + |
| 472 | + def test_reduce_lr_on_plateau4(self): |
| 473 | + for param_group in self.opt.param_groups: |
| 474 | + param_group['lr'] = 0.5 |
| 475 | + targets = [[0.5] * 20] |
| 476 | + metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25 |
| 477 | + scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3, |
| 478 | + threshold_mode='rel', threshold=0.1) |
| 479 | + epochs = 20 |
| 480 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 481 | + |
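The inline comment `1.025 > 1.1**0.25` is the key fact here: with mode='max' and threshold_mode='rel', an epoch improves only when metric > best * (1 + threshold), and a 2.5% per-epoch growth clears the 10% bar within 4 epochs, so there are never more than `patience` bad epochs in a row. A rough emulation of that bookkeeping (a sketch of the arithmetic, not the scheduler's actual implementation):

    threshold, patience = 0.1, 3
    metrics = [1.5 * (1.025 ** i) for i in range(20)]
    best, num_bad = metrics[0], 0
    for m in metrics[1:]:
        if m > best * (1 + threshold):  # relative improvement in 'max' mode
            best, num_bad = m, 0
        else:
            num_bad += 1
        assert num_bad <= patience      # never exceeds patience, so lr stays at 0.5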
| 482 | + def test_reduce_lr_on_plateau5(self): |
| 483 | + for param_group in self.opt.param_groups: |
| 484 | + param_group['lr'] = 0.5 |
| 485 | + targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] |
| 486 | + metrics = [1.5 * (1.005 ** i) for i in range(20)] |
| 487 | + scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', |
| 488 | + threshold=0.1, patience=5, cooldown=5) |
| 489 | + epochs = 20 |
| 490 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 491 | + |
| 492 | + def test_reduce_lr_on_plateau6(self): |
| 493 | + for param_group in self.opt.param_groups: |
| 494 | + param_group['lr'] = 0.5 |
| 495 | + targets = [[0.5] * 20] |
| 496 | + metrics = [1.5 * (0.85 ** i) for i in range(20)] |
| 497 | + scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', |
| 498 | + threshold=0.1) |
| 499 | + epochs = 20 |
| 500 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 501 | + |
| 502 | + def test_reduce_lr_on_plateau7(self): |
| 503 | + for param_group in self.opt.param_groups: |
| 504 | + param_group['lr'] = 0.5 |
| 505 | + targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] |
| 506 | + metrics = [1] * 7 + [0.6] + [0.5] * 12 |
| 507 | + scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', |
| 508 | + threshold=0.1, patience=5, cooldown=5) |
| 509 | + epochs = 20 |
| 510 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 511 | + |
| 512 | + def test_reduce_lr_on_plateau8(self): |
| 513 | + for param_group in self.opt.param_groups: |
| 514 | + param_group['lr'] = 0.5 |
| 515 | + targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] |
| 516 | + metrics = [1.5 * (1.005 ** i) for i in range(20)] |
| 517 | + scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3], |
| 518 | + threshold=0.1, patience=5, cooldown=5) |
| 519 | + epochs = 20 |
| 520 | + self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| 521 | + |
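min_lr can be given per param group; when a reduction fires, each group's new lr is clamped from below, i.e. new_lr = max(old_lr * factor, min_lr[i]), with factor defaulting to 0.1. That is why both groups stop at their floors here instead of dropping to 0.05 (sketch):

    factor, old_lr, min_lrs = 0.1, 0.5, [0.4, 0.3]
    new_lrs = [max(old_lr * factor, floor) for floor in min_lrs]
    assert new_lrs == [0.4, 0.3]  # matches the two target schedules above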
| 522 | + def test_lambda_lr(self): |
| 523 | + self.opt.param_groups[0]['lr'] = 0.05 |
| 524 | + self.opt.param_groups[1]['lr'] = 0.4 |
| 525 | + targets = [[0.05 * (0.9 ** x) for x in range(10)], [0.4 * (0.8 ** x) for x in range(10)]] |
| 526 | + scheduler = LambdaLR(self.opt, |
| 527 | + lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2]) |
| 528 | + epochs = 10 |
| 529 | + self._test(scheduler, targets, epochs) |
| 530 | + |
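LambdaLR takes one callable per param group (or a single shared one) and sets lr = base_lr * lr_lambda(epoch). The two target lists can be rebuilt directly (sketch):

    base_lrs = [0.05, 0.4]
    lambdas = [lambda e: 0.9 ** e, lambda e: 0.8 ** e]
    expected = [[base * fn(epoch) for epoch in range(10)]
                for base, fn in zip(base_lrs, lambdas)]
    # -> identical to `targets` in test_lambda_lr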
| 531 | + def _test(self, scheduler, targets, epochs=10): |
| 532 | + for epoch in range(epochs): |
| 533 | + scheduler.step(epoch) |
| 534 | + for param_group, target in zip(self.opt.param_groups, targets): |
| 535 | + self.assertAlmostEqual(target[epoch], param_group['lr'], |
| 536 | + msg='LR is wrong in epoch {}: expected {}, got {}'.format( |
| 537 | + epoch, target[epoch], param_group['lr']), delta=1e-5) |
| 538 | + |
| 539 | + def _test_reduce_lr_on_plateau(self, scheduler, targets, metrics, epochs=10, verbose=False): |
| 540 | + for epoch in range(epochs): |
| 541 | + scheduler.step(metrics[epoch]) |
| 542 | + if verbose: |
| 543 | + print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr'])) |
| 544 | + for param_group, target in zip(self.opt.param_groups, targets): |
| 545 | + self.assertAlmostEqual(target[epoch], param_group['lr'], |
| 546 | + msg='LR is wrong in epoch {}: expected {}, got {}'.format( |
| 547 | + epoch, target[epoch], param_group['lr']), delta=1e-5) |
| 548 | + |
395 | 549 | if __name__ == '__main__': |
396 | 550 | run_tests() |