mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[checkpoint]support generalized scheduler (#1222)
This commit is contained in:
@@ -2,6 +2,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class _enable_get_lr_call:
|
||||
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
@@ -33,6 +34,16 @@ class DelayerScheduler(_LRScheduler):
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
if isinstance(state_dict['after_scheduler'], _LRScheduler):
|
||||
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
|
||||
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.delay_epochs:
|
||||
if not self.finished:
|
||||
@@ -73,6 +84,16 @@ class WarmupScheduler(_LRScheduler):
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
if isinstance(state_dict['after_scheduler'], _LRScheduler):
|
||||
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
|
||||
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs:
|
||||
if not self.finished:
|
||||
@@ -118,6 +139,16 @@ class WarmupDelayerScheduler(_LRScheduler):
|
||||
self.finished = False
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
|
||||
if isinstance(state_dict['after_scheduler'], _LRScheduler):
|
||||
state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
|
||||
state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
|
||||
del state_dict['after_scheduler']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return state_dict
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.warmup_epochs + self.delay_epochs:
|
||||
if not self.finished:
|
||||
|
Reference in New Issue
Block a user