LR_Scheduler.py

##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import math

__all__ = ['LR_Scheduler', 'LR_Scheduler_Head']

class LR_Scheduler(object):
    """Learning rate scheduler.

    Step mode:   ``lr = base_lr * 0.1 ^ floor(epoch / lr_step)``

    Cosine mode: ``lr = base_lr * 0.5 * (1 + cos(pi * iter / maxiter))``

    Poly mode:   ``lr = base_lr * (1 - iter / maxiter) ^ 0.9``

    Args:
        mode: lr scheduler mode ('cos', 'poly' or 'step')
        base_lr: base learning rate
        num_epochs: total number of training epochs
        iters_per_epoch: number of iterations per epoch
        lr_step: epoch interval for decaying the lr (step mode only)
        warmup_epochs: number of epochs with a linear lr warm-up
    """
    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
                 lr_step=0, warmup_epochs=0, quiet=False):
        self.mode = mode
        self.quiet = quiet
        if not quiet:
            print('Using {} LR scheduler with warm-up epochs of {}!'.format(
                self.mode, warmup_epochs))
        if mode == 'step':
            assert lr_step
        self.base_lr = base_lr
        self.lr_step = lr_step
        self.iters_per_epoch = iters_per_epoch
        self.epoch = -1
        self.warmup_iters = warmup_epochs * iters_per_epoch
        self.total_iters = (num_epochs - warmup_epochs) * iters_per_epoch

    def __call__(self, optimizer, i, epoch, best_pred):
        T = epoch * self.iters_per_epoch + i
        # linear warm-up over the first warmup_iters iterations
        if self.warmup_iters > 0 and T < self.warmup_iters:
            lr = self.base_lr * 1.0 * T / self.warmup_iters
        elif self.mode == 'cos':
            T = T - self.warmup_iters
            lr = 0.5 * self.base_lr * (1 + math.cos(1.0 * T / self.total_iters * math.pi))
        elif self.mode == 'poly':
            T = T - self.warmup_iters
            lr = self.base_lr * pow((1 - 1.0 * T / self.total_iters), 0.9)
        elif self.mode == 'step':
            lr = self.base_lr * (0.1 ** (epoch // self.lr_step))
        else:
            raise NotImplementedError('Unknown lr scheduler mode: {}'.format(self.mode))
        if epoch > self.epoch and (epoch == 0 or best_pred > 0.0):
            if not self.quiet:
                print('\n=>Epoch %i, learning rate = %.4f, '
                      'previous best = %.4f' % (epoch, lr, best_pred))
            self.epoch = epoch
        assert lr >= 0
        self._adjust_learning_rate(optimizer, lr)

    def _adjust_learning_rate(self, optimizer, lr):
        # apply the same lr to every parameter group
        for i in range(len(optimizer.param_groups)):
            optimizer.param_groups[i]['lr'] = lr

class LR_Scheduler_Head(LR_Scheduler):
    """Increase the learning rate of the additional head to 10 times the base lr."""
    def _adjust_learning_rate(self, optimizer, lr):
        if len(optimizer.param_groups) == 1:
            optimizer.param_groups[0]['lr'] = lr
        else:
            # param group 0 is assumed to hold the backbone; the remaining
            # groups hold the head and get the enlarged (10x) lr
            optimizer.param_groups[0]['lr'] = lr
            for i in range(1, len(optimizer.param_groups)):
                optimizer.param_groups[i]['lr'] = lr * 10
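
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes a
# PyTorch setup; the tiny backbone/head modules and hyper-parameters below are
# placeholders chosen only to show how the scheduler is wired to an optimizer
# with separate backbone/head parameter groups and called once per iteration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    backbone = torch.nn.Linear(8, 8)
    head = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(
        [{'params': backbone.parameters()},   # group 0: base lr
         {'params': head.parameters()}],      # groups 1+: 10x lr via LR_Scheduler_Head
        lr=0.01, momentum=0.9)

    num_epochs, iters_per_epoch, best_pred = 3, 5, 0.0
    scheduler = LR_Scheduler_Head('poly', base_lr=0.01, num_epochs=num_epochs,
                                  iters_per_epoch=iters_per_epoch, warmup_epochs=1)

    for epoch in range(num_epochs):
        for i in range(iters_per_epoch):
            scheduler(optimizer, i, epoch, best_pred)  # set lr before the step
            optimizer.zero_grad()
            loss = (head(backbone(torch.randn(4, 8))) ** 2).mean()
            loss.backward()
            optimizer.step()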