Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -2,7 +2,6 @@ from torch.optim.lr_scheduler import _LRScheduler
 
 
 class _enable_get_lr_call:
-
     def __init__(self, o):
         self.o = o
 
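The hunk above shows only the constructor of `_enable_get_lr_call`. A plausible reconstruction of the full helper, assuming it mirrors the private context manager of the same name in `torch.optim.lr_scheduler` (the `__enter__`/`__exit__` bodies are assumptions, not part of this commit):

class _enable_get_lr_call:
    # Sketch: only __init__ appears in the diff; __enter__/__exit__ are
    # assumed to match torch.optim.lr_scheduler's internal helper.
    def __init__(self, o):
        self.o = o

    def __enter__(self):
        # Mark that get_lr() is being invoked from within step(), which
        # suppresses PyTorch's "called outside step" warning.
        self.o._get_lr_called_within_step = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.o._get_lr_called_within_step = False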
@@ -28,18 +27,18 @@ class DelayerScheduler(_LRScheduler):
 
     def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1):
         if delay_epochs < 0:
-            raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}')
+            raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}")
         self.delay_epochs = delay_epochs
         self.after_scheduler = after_scheduler
         self.finished = False
         super().__init__(optimizer, last_epoch)
 
     def state_dict(self):
-        state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
-        if isinstance(state_dict['after_scheduler'], _LRScheduler):
-            state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
-            state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
-            del state_dict['after_scheduler']
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
+        if isinstance(state_dict["after_scheduler"], _LRScheduler):
+            state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
+            state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
+            del state_dict["after_scheduler"]
         else:
             raise NotImplementedError()
         return state_dict
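For orientation, a minimal usage sketch of `DelayerScheduler` based on the constructor signature in this hunk; the import path, model, and hyperparameters are illustrative assumptions, not taken from the commit:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
# Assumed import path for the class defined in the file being diffed:
from colossalai.nn.lr_scheduler import DelayerScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
cosine = CosineAnnealingLR(optimizer, T_max=90)

# Hold the LR at its base value for 10 epochs, then hand off to cosine decay.
scheduler = DelayerScheduler(optimizer, delay_epochs=10, after_scheduler=cosine)

for epoch in range(100):
    # ... forward/backward and optimizer.step() go here ...
    scheduler.step()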
@@ -85,11 +84,11 @@ class WarmupScheduler(_LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     def state_dict(self):
-        state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
-        if isinstance(state_dict['after_scheduler'], _LRScheduler):
-            state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
-            state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
-            del state_dict['after_scheduler']
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
+        if isinstance(state_dict["after_scheduler"], _LRScheduler):
+            state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
+            state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
+            del state_dict["after_scheduler"]
         else:
             raise NotImplementedError()
         return state_dict
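The quote change keeps the original semantics of `key not in 'optimizer'`, which is worth noting: `not in` against a string is a substring test, not a membership test on a collection. It excludes the key "optimizer" only because a string is a substring of itself; any attribute whose name happened to be a substring of "optimizer" (say, a hypothetical `opt`) would be dropped too:

# Demonstration of the substring semantics used by the comprehension above.
state = {"optimizer": object(), "after_scheduler": object(), "finished": False}
filtered = {k: v for k, v in state.items() if k not in "optimizer"}
assert "optimizer" not in filtered    # dropped: a string is a substring of itself
assert "after_scheduler" in filtered  # kept: not a substring of "optimizer"
assert "opt" in "optimizer"           # a key named "opt" would also be dropped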
@@ -130,9 +129,9 @@ class WarmupDelayerScheduler(_LRScheduler):
 
     def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1):
         if delay_epochs < 0:
-            raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}')
+            raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}")
         if warmup_epochs < 0:
-            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
+            raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}")
         self.warmup_epochs = warmup_epochs
         self.delay_epochs = delay_epochs
         self.after_scheduler = after_scheduler
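Given the signature above, `WarmupDelayerScheduler` composes three phases: warmup, a hold at the base LR for `delay_epochs`, then the wrapped scheduler. An illustrative construction, with the guard behavior from this hunk (the import path and hyperparameters are assumptions):

import torch
from torch.optim.lr_scheduler import MultiStepLR
from colossalai.nn.lr_scheduler import WarmupDelayerScheduler  # assumed path

optimizer = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
steps = MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)

# Warm up for 5 epochs, hold the base LR for 10 more, then step-decay.
scheduler = WarmupDelayerScheduler(
    optimizer, warmup_epochs=5, delay_epochs=10, after_scheduler=steps
)

# Negative epoch counts trip the guards shown above.
try:
    WarmupDelayerScheduler(optimizer, warmup_epochs=-1, delay_epochs=0, after_scheduler=steps)
except ValueError as err:
    print(err)  # warmup_epochs must >= 0, got -1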
@@ -140,11 +139,11 @@ class WarmupDelayerScheduler(_LRScheduler):
         super().__init__(optimizer, last_epoch)
 
     def state_dict(self):
-        state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'}
-        if isinstance(state_dict['after_scheduler'], _LRScheduler):
-            state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__
-            state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict()
-            del state_dict['after_scheduler']
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"}
+        if isinstance(state_dict["after_scheduler"], _LRScheduler):
+            state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__
+            state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict()
+            del state_dict["after_scheduler"]
         else:
             raise NotImplementedError()
         return state_dict
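All three classes share this `state_dict` pattern: the wrapped scheduler is flattened into `after_scheduler_type` and `after_scheduler_dict` entries. The commit does not show a matching `load_state_dict`; a purely hypothetical counterpart that reassembles the state might look like:

def load_state_dict(self, state_dict):
    # Hypothetical sketch, not part of this commit: undo the flattening
    # performed by state_dict() above.
    state_dict = dict(state_dict)  # avoid mutating the caller's dict
    after_dict = state_dict.pop("after_scheduler_dict", None)
    state_dict.pop("after_scheduler_type", None)
    self.__dict__.update(state_dict)
    if after_dict is not None:
        # The wrapped scheduler object is recreated by __init__;
        # only its inner state is restored here.
        self.after_scheduler.load_state_dict(after_dict)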
@@ -155,7 +154,7 @@ class WarmupDelayerScheduler(_LRScheduler):
                 self.after_scheduler.base_lrs = self.base_lrs
                 # reset lr to base_lr
                 for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
-                    group['lr'] = base_lr
+                    group["lr"] = base_lr
                 self.finished = True
             with _enable_get_lr_call(self.after_scheduler):
                 return self.after_scheduler.get_lr()
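This last hunk is the handoff point: once the warmup and delay phases end, the param-group learning rates are reset to `base_lrs` before delegating to `after_scheduler`, so the wrapped scheduler starts from a clean baseline rather than whatever value warmup left behind. A minimal sketch of that reset, assuming standard PyTorch param-group semantics:

import torch

opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
base_lrs = [g["lr"] for g in opt.param_groups]  # snapshot, as _LRScheduler does

opt.param_groups[0]["lr"] = 0.03  # e.g. a value left over from warmup
# reset lr to base_lr (mirrors the loop in the hunk above)
for group, base_lr in zip(opt.param_groups, base_lrs):
    group["lr"] = base_lr
assert opt.param_groups[0]["lr"] == 0.1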