[checkpoint]support generalized scheduler (#1222)

This commit is contained in:
Yi Zhao
2022-07-07 18:16:38 +08:00
committed by GitHub
parent a98319f023
commit 04537bf83e
4 changed files with 85 additions and 20 deletions

View File

@@ -2,6 +2,7 @@ from torch.optim.lr_scheduler import _LRScheduler
class _enable_get_lr_call:
def __init__(self, o):
self.o = o
@@ -33,6 +34,16 @@ class DelayerScheduler(_LRScheduler):
self.finished = False
super().__init__(optimizer, last_epoch)
def state_dict(self):
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
if isinstance(state_dict['after_scheduler'], _LRScheduler):
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
del state_dict['after_scheduler']
else:
raise NotImplementedError()
return state_dict
def get_lr(self):
if self.last_epoch >= self.delay_epochs:
if not self.finished:
@@ -73,6 +84,16 @@ class WarmupScheduler(_LRScheduler):
self.finished = False
super().__init__(optimizer, last_epoch)
def state_dict(self):
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
if isinstance(state_dict['after_scheduler'], _LRScheduler):
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
del state_dict['after_scheduler']
else:
raise NotImplementedError()
return state_dict
def get_lr(self):
if self.last_epoch >= self.warmup_epochs:
if not self.finished:
@@ -118,6 +139,16 @@ class WarmupDelayerScheduler(_LRScheduler):
self.finished = False
super().__init__(optimizer, last_epoch)
def state_dict(self):
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
if isinstance(state_dict['after_scheduler'], _LRScheduler):
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
del state_dict['after_scheduler']
else:
raise NotImplementedError()
return state_dict
def get_lr(self):
if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
if not self.finished: