mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699
.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
@@ -27,12 +27,7 @@ class MultiStepLR(_MultiStepLR):
|
||||
:type last_epoch: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1,
|
||||
num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
|
||||
if num_steps_per_epoch <= 0:
|
||||
raise ValueError(
|
||||
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
|
||||
milestones = [v * num_steps_per_epoch for v in milestones]
|
||||
def __init__(self, optimizer, total_steps: int, milestones: List[int] = None, gamma: float = 0.1, last_epoch: int = -1, **kwargs):
|
||||
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
|
||||
|
||||
|
||||
@@ -57,14 +52,11 @@ class MultiStepWarmupLR(WarmupScheduler):
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, milestones: List[int] = None,
|
||||
gamma: float = 0.1, num_steps_per_epoch: int = -1, last_epoch: int = -1, **kwargs):
|
||||
gamma: float = 0.1, last_epoch: int = -1, **kwargs):
|
||||
if len(milestones) == 0:
|
||||
raise ValueError('milestones cannot be empty')
|
||||
if num_steps_per_epoch <= 0:
|
||||
raise ValueError(
|
||||
f'num_steps_per_epoch must > 0, got {num_steps_per_epoch}')
|
||||
milestones = [v * num_steps_per_epoch - warmup_steps for v in milestones if v *
|
||||
num_steps_per_epoch >= warmup_steps]
|
||||
milestones = [
|
||||
v - warmup_steps for v in milestones if v >= warmup_steps]
|
||||
base_scheduler = _MultiStepLR(optimizer, milestones=milestones,
|
||||
gamma=gamma)
|
||||
super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
|
||||
|
Reference in New Issue
Block a user