mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-22 07:14:09 +00:00
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699
.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
@@ -10,12 +9,11 @@ from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.builder import build_hooks
|
||||
from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
|
||||
from colossalai.context import Config
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.logging import get_global_dist_logger
|
||||
from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
|
||||
from colossalai.nn.data import DataParallelSampler
|
||||
from colossalai.utils import MultiTimer
|
||||
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
|
||||
|
||||
|
||||
class Trainer:
|
||||
@@ -30,43 +28,31 @@ class Trainer:
|
||||
:type hoooks_cfg: Config, optional
|
||||
:type verbose: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
engine: Engine,
|
||||
hooks_cfg: Optional[Config] = None,
|
||||
verbose: bool = False):
|
||||
verbose: bool = False,
|
||||
timer: MultiTimer = None):
|
||||
# training-ralated params
|
||||
self._engine = engine
|
||||
self._max_epochs = float('inf')
|
||||
self._max_steps = float('inf')
|
||||
self._max_epochs = 0
|
||||
self._cur_epoch = 0
|
||||
self._max_steps = 0
|
||||
self._cur_step = 0
|
||||
|
||||
# data-related params
|
||||
self._train_dataloader = None
|
||||
self._test_dataloader = None
|
||||
self._steps_per_epoch = 0
|
||||
|
||||
# misc params
|
||||
self._display_progress = False
|
||||
self._logger = get_global_dist_logger()
|
||||
self._verbose = verbose
|
||||
|
||||
# hooks can store states in this dict, and could be consumed by other hooks
|
||||
self.states = {}
|
||||
self.states = dict()
|
||||
|
||||
# build hooks
|
||||
self.hooks = list()
|
||||
if hooks_cfg is not None:
|
||||
for cfg in hooks_cfg:
|
||||
hook = build_hooks(cfg, self)
|
||||
self.hooks.append(hook)
|
||||
self.hooks.sort(key=lambda hook: hook.priority)
|
||||
if self._verbose:
|
||||
for hook in self.hooks:
|
||||
self._logger.info(
|
||||
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
|
||||
|
||||
# timer
|
||||
self._timer = get_global_multitimer()
|
||||
# multi-timer for time benchmarking
|
||||
self._timer = timer
|
||||
|
||||
@property
|
||||
def cur_epoch(self):
|
||||
@@ -74,13 +60,65 @@ class Trainer:
|
||||
"""
|
||||
return self._cur_epoch
|
||||
|
||||
@cur_epoch.setter
|
||||
def cur_epoch(self, epoch: int):
|
||||
"""Set how many epochs have been processed.
|
||||
"""
|
||||
# allow setter for training resumption
|
||||
self._cur_epoch = epoch
|
||||
|
||||
@property
|
||||
def cur_step(self):
|
||||
"""Returns how many iteration steps have been processed.
|
||||
"""
|
||||
return self._cur_step
|
||||
|
||||
def call_hooks(self, func, output=None):
|
||||
@property
|
||||
def max_epochs(self):
|
||||
return self._max_epochs
|
||||
|
||||
@property
|
||||
def max_steps(self):
|
||||
return self._max_steps
|
||||
|
||||
@property
|
||||
def steps_per_epoch(self):
|
||||
return self._steps_per_epoch
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return self._engine
|
||||
|
||||
@engine.setter
|
||||
def engine(self, engine_: Engine):
|
||||
self._engine = engine_
|
||||
|
||||
def _set_current_step(self, epoch: int):
|
||||
"""Sets current step number.
|
||||
|
||||
:param epoch: Step number to be set
|
||||
:type epoch: int
|
||||
"""
|
||||
self._cur_step = epoch * self._steps_per_epoch
|
||||
|
||||
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
|
||||
"""Call timer funciton with a given timer name.
|
||||
|
||||
:param action: Function to be called on timer
|
||||
:type action: str
|
||||
:param item: Name of the timer
|
||||
:type item: str
|
||||
"""
|
||||
|
||||
if self._timer is not None:
|
||||
getattr(self._timer, action)(item, *args, **kwargs)
|
||||
|
||||
def _reset_states(self) -> None:
|
||||
"""Clear trainer states
|
||||
"""
|
||||
self.states = dict()
|
||||
|
||||
def _call_hooks(self, func, output=None):
|
||||
"""Calls specific hooks in the current time point.
|
||||
|
||||
:param func: A string represents the time point
|
||||
@@ -95,161 +133,186 @@ class Trainer:
|
||||
else:
|
||||
getattr(hook, func)(*output)
|
||||
|
||||
def exceed_max_step(self):
|
||||
"""Checks whether the trainer exceeds the maximum number of runnning iterations.
|
||||
@staticmethod
|
||||
def _should_display_progress(display_progress: bool):
|
||||
""" Only display progress on DP rank 0, TP rank 0 and PP last rank
|
||||
"""
|
||||
return self._cur_step >= self._max_steps
|
||||
return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
"""Sets current epoch number.
|
||||
|
||||
:param epoch: Epoch number to be set
|
||||
:type epoch: int
|
||||
"""
|
||||
self._cur_epoch = epoch
|
||||
|
||||
def _recover_steps(self):
|
||||
step = self.cur_step * self._engine.schedule.num_steps
|
||||
self._cur_step = step
|
||||
|
||||
def _set_display_progress(self, display_progress: bool):
|
||||
self._display_progress = display_progress and is_dp_rank_0(
|
||||
) and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
def _train_epoch(self, epoch: int = None):
|
||||
def _train_epoch(self,
|
||||
train_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False):
|
||||
# set sampler epoch
|
||||
if epoch is not None and \
|
||||
hasattr(self._engine.train_dataloader, 'sampler') and \
|
||||
isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
|
||||
self._engine.train_dataloader.sampler.set_epoch(epoch)
|
||||
hasattr(train_dataloader, 'sampler') and \
|
||||
isinstance(train_dataloader.sampler, DataParallelSampler):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
|
||||
# set training state
|
||||
self._engine.train()
|
||||
|
||||
progress = range(self._engine.schedule.num_steps)
|
||||
if self._display_progress:
|
||||
data_iter = iter(train_dataloader)
|
||||
progress = range(self._steps_per_epoch)
|
||||
if display_progress:
|
||||
if epoch is None:
|
||||
progress = tqdm(progress, desc='[Train]')
|
||||
else:
|
||||
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
|
||||
|
||||
# train 1 epoch
|
||||
self.call_hooks('before_train_epoch')
|
||||
self._timer.start('train-epoch')
|
||||
for _ in progress:
|
||||
self._call_hooks('before_train_epoch')
|
||||
self._call_timer(action='start', item='train-epoch')
|
||||
for i in progress:
|
||||
self._call_hooks('before_train_iter')
|
||||
self._call_timer(action='start', item='train-step')
|
||||
|
||||
if i == self._steps_per_epoch - 1:
|
||||
is_last_iteration = True
|
||||
else:
|
||||
is_last_iteration = False
|
||||
|
||||
# run 1 training step
|
||||
logits, label, loss = self._engine.step(data_iter, is_last_iteration)
|
||||
self._call_timer(action='stop', item='train-step', keep_in_history=True)
|
||||
self._call_hooks('after_train_iter', output=(logits, label, loss))
|
||||
|
||||
self._cur_step += 1
|
||||
|
||||
self.call_hooks('before_train_iter')
|
||||
self._timer.start('train-step')
|
||||
logits, label, loss = self._engine.step()
|
||||
self._timer.stop('train-step', keep_in_history=True)
|
||||
self.call_hooks('after_train_iter', output=(logits, label, loss))
|
||||
|
||||
if self.exceed_max_step():
|
||||
# stop when max iter is reached
|
||||
# stop when max iter is reached
|
||||
if self._exceed_max_step():
|
||||
break
|
||||
self._timer.stop('train-epoch', keep_in_history=True)
|
||||
self.call_hooks('after_train_epoch')
|
||||
self._timer.reset('train-step')
|
||||
|
||||
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
|
||||
self._call_hooks('after_train_epoch')
|
||||
self._call_timer(action='reset', item='train-step')
|
||||
|
||||
def _eval(self,
|
||||
test_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
return_loss: bool = True):
|
||||
display_progress: bool = False):
|
||||
# switch engine status
|
||||
self._engine.eval()
|
||||
|
||||
self.call_hooks('before_test')
|
||||
data_iter = iter(test_dataloader)
|
||||
num_steps = len(test_dataloader)
|
||||
|
||||
self._call_hooks('before_test')
|
||||
with torch.no_grad():
|
||||
# prepare progress bar
|
||||
progress = range(self._engine.schedule.num_steps)
|
||||
if self._display_progress:
|
||||
progress = range(num_steps)
|
||||
if display_progress:
|
||||
desc = 'Evaluation'
|
||||
if epoch is not None:
|
||||
desc = '[Epoch %d val]' % epoch
|
||||
progress = tqdm(progress, desc=desc)
|
||||
|
||||
self.call_hooks('before_test_epoch')
|
||||
self._timer.start('test-epoch')
|
||||
self._call_hooks('before_test_epoch')
|
||||
self._call_timer(action='start', item='test-epoch')
|
||||
for _ in progress:
|
||||
self.call_hooks('before_test_iter')
|
||||
self._timer.start('test-step')
|
||||
logits, label, loss = self._engine.step(
|
||||
return_loss=return_loss)
|
||||
self._timer.stop('test-step', keep_in_history=True)
|
||||
self.call_hooks('after_test_iter',
|
||||
output=(logits, label, loss))
|
||||
self._timer.stop('test-epoch', keep_in_history=True)
|
||||
self.call_hooks('after_test_epoch')
|
||||
self.call_hooks('after_test')
|
||||
self._timer.reset('test-step')
|
||||
self._timer.reset('test-epoch')
|
||||
self._call_hooks('before_test_iter')
|
||||
self._call_timer(action='start', item='test-step')
|
||||
logits, label, loss = self._engine.step(data_iter, return_loss=True)
|
||||
self._call_timer(action='stop', item='test-step', keep_in_history=True)
|
||||
self._call_hooks('after_test_iter',
|
||||
output=(logits, label, loss))
|
||||
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
|
||||
self._call_hooks('after_test_epoch')
|
||||
self._call_hooks('after_test')
|
||||
self._call_timer(action='reset', item='test-step')
|
||||
self._call_timer(action='reset', item='test-epoch')
|
||||
|
||||
def _exceed_max_step(self):
|
||||
return self._max_steps is not None and self._cur_step > self._max_steps
|
||||
|
||||
def fit(self,
|
||||
train_dataloader: DataLoader,
|
||||
test_dataloader: DataLoader = None,
|
||||
max_epochs: int = None,
|
||||
epochs: int,
|
||||
max_steps: int = None,
|
||||
test_dataloader: DataLoader = None,
|
||||
test_interval: int = 1,
|
||||
display_progress: bool = False):
|
||||
hooks_cfg: dict = None,
|
||||
display_progress: bool = False,
|
||||
):
|
||||
"""Trains the model to fit training data.
|
||||
|
||||
:param train_dataloader: DataLoader in training
|
||||
:param test_dataloader: DataLoader in testing
|
||||
:param max_epochs: Maximum number of epoches
|
||||
:param epochs: Maximum number of epoches
|
||||
:param max_steps: Maximum number of running iterations
|
||||
:param test_dataloader: DataLoader in testing
|
||||
:param test_interval: Interval of testing
|
||||
:param hooks_cfg: A list of hook configuration
|
||||
:param display_progress: If True, the training progress will be printed
|
||||
:type train_dataloader: DataLoader
|
||||
:type test_dataloader: DataLoader
|
||||
:type max_epochs: int
|
||||
:type epochs: int
|
||||
:type max_steps: int
|
||||
:type test_dataloader: DataLoader
|
||||
:type test_interval: int
|
||||
:type hooks_cfg: dict
|
||||
:type display_progress: bool
|
||||
:type gradient_accumulation: int
|
||||
"""
|
||||
|
||||
# prepare dataloaders
|
||||
self._train_dataloader = train_dataloader
|
||||
self._engine.set_dataloader(self._train_dataloader, train=True)
|
||||
self._engine.train()
|
||||
# set epochs and steps, consider gradient accumulation
|
||||
self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
|
||||
self._max_steps = max_steps
|
||||
self._max_epochs = epochs
|
||||
|
||||
# check if testing is required
|
||||
should_test = False
|
||||
if test_dataloader is not None:
|
||||
self._test_dataloader = test_dataloader
|
||||
self._engine.set_dataloader(self._test_dataloader, train=False)
|
||||
should_test = True
|
||||
|
||||
# decide the
|
||||
if max_epochs is not None:
|
||||
self._max_epochs = max_epochs
|
||||
if max_steps is not None:
|
||||
self._max_steps = max_steps
|
||||
self._set_display_progress(display_progress)
|
||||
display_progress = self._should_display_progress(display_progress)
|
||||
|
||||
# reset hooks
|
||||
self._reset_states()
|
||||
self.hooks = list()
|
||||
|
||||
# build hooks
|
||||
if hooks_cfg is not None:
|
||||
for cfg in hooks_cfg:
|
||||
hook = build_hooks(cfg, self)
|
||||
self.hooks.append(hook)
|
||||
self.hooks.sort(key=lambda hook: hook.priority)
|
||||
if self._verbose:
|
||||
for hook in self.hooks:
|
||||
self._logger.info(
|
||||
f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
|
||||
self._logger.info("Lower value means higher priority for calling hook function")
|
||||
|
||||
# start train
|
||||
self.call_hooks('before_train')
|
||||
self._engine.train()
|
||||
self._call_hooks('before_train')
|
||||
|
||||
# recover step value if resuming training
|
||||
if self.cur_epoch != 0:
|
||||
self._recover_steps()
|
||||
|
||||
last_epoch = self._cur_epoch
|
||||
if self.cur_epoch != 0:
|
||||
self._set_current_step(last_epoch)
|
||||
|
||||
for epoch in range(last_epoch, self._max_epochs):
|
||||
self._cur_epoch += 1
|
||||
|
||||
for epoch in range(last_epoch, epochs):
|
||||
# train for one epoch
|
||||
self._train_epoch(epoch)
|
||||
self._train_epoch(
|
||||
train_dataloader=train_dataloader,
|
||||
epoch=epoch,
|
||||
display_progress=display_progress
|
||||
)
|
||||
|
||||
# start eval
|
||||
if should_test and epoch % test_interval == 0:
|
||||
self._eval(epoch, return_loss=True)
|
||||
self._eval(test_dataloader=test_dataloader,
|
||||
display_progress=display_progress,
|
||||
epoch=epoch,
|
||||
)
|
||||
|
||||
self._cur_epoch += 1
|
||||
|
||||
# check for termination
|
||||
if self.exceed_max_step():
|
||||
if self._exceed_max_step():
|
||||
self._logger.info(
|
||||
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
|
||||
f"Max number of steps {max_steps} has been reached, training is stopped automatically")
|
||||
break
|
||||
self.call_hooks('after_train')
|
||||
self._timer.reset('train-epoch')
|
||||
self._call_hooks('after_train')
|
||||
self._call_timer('reset', 'train-epoch')
|
||||
|
||||
def evaluate(self,
|
||||
test_dataloader: DataLoader,
|
||||
@@ -261,15 +324,13 @@ class Trainer:
|
||||
:type test_dataloader: DataLoader
|
||||
:type display_progress: bool, optional
|
||||
"""
|
||||
# set dataloader
|
||||
self._test_dataloader = test_dataloader
|
||||
self._engine.set_dataloader(self._test_dataloader, train=True)
|
||||
|
||||
# set
|
||||
self._set_display_progress(display_progress)
|
||||
# set display
|
||||
display_progress = self._should_display_progress(display_progress)
|
||||
|
||||
# eval
|
||||
self._eval(return_loss=True)
|
||||
self._eval(test_dataloader=test_dataloader,
|
||||
display_progress=display_progress,
|
||||
)
|
||||
|
||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||
@@ -289,45 +350,6 @@ class Trainer:
|
||||
# prepare a list of (data, label) to make it iterable
|
||||
# for compatibility with schedule
|
||||
simple_dataloader = [(data, None)]
|
||||
self._engine.set_dataloader(simple_dataloader)
|
||||
output, _, _ = self._engine.step(return_loss=False)
|
||||
data_iter = iter(simple_dataloader)
|
||||
output, _, _ = self._engine.step(data_iter, return_loss=False)
|
||||
return output
|
||||
|
||||
def save(self, path: str, suffix: str = ''):
|
||||
"""Saves the model to a file.
|
||||
|
||||
:param path: Relative path of the file
|
||||
:param suffix: Suffix of the file
|
||||
:type path: str
|
||||
:type suffix: str, optional
|
||||
"""
|
||||
save_path = get_checkpoint_path(path,
|
||||
self._cur_epoch,
|
||||
suffix=suffix)
|
||||
save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
|
||||
self._engine.get_optimizer(),
|
||||
self._engine.get_lr_scheduler())
|
||||
|
||||
def load(self,
|
||||
path: str,
|
||||
finetune: bool = False,
|
||||
strict: bool = False):
|
||||
"""Loads parameters to the model from a file.
|
||||
|
||||
:param path: Relative path of the file
|
||||
:param finetune: Whether allows to load a part of the model
|
||||
:param strict: Whether loads a model that has the same shape of parameters
|
||||
:type path: str
|
||||
:type finetune: bool, optional
|
||||
:type strict: bool, optional
|
||||
"""
|
||||
last_epoch, _ = load_checkpoint(path,
|
||||
self._engine.get_model(),
|
||||
self._engine.get_optimizer(),
|
||||
self._engine.get_lr_scheduler(),
|
||||
finetune=finetune,
|
||||
strict=strict)
|
||||
if finetune:
|
||||
self.set_epoch(0)
|
||||
else:
|
||||
self.set_epoch(last_epoch)
|
||||
|
Reference in New Issue
Block a user