[refactor] moving grad acc logic to engine (#804)

This commit is contained in:
Jiarui Fang
2022-04-19 14:03:21 +08:00
committed by GitHub
parent 05d9ae5999
commit 681addb512
8 changed files with 26 additions and 20 deletions

View File

@@ -0,0 +1,50 @@
import torch.nn as nn
from typing import List
from colossalai.engine import BaseGradientHandler
from typing import Iterable
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler
__all__ = [
'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
'GradAccumGradientHandler'
]
def accumulate_gradient(model: nn.Module,
optimizer: Optimizer,
dataloader: Iterable,
accumulate_size: int,
gradient_handlers: List[BaseGradientHandler] = None,
lr_scheduler: _LRScheduler = None):
r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation.
Args:
model (:class:`torch.nn.Module`): your model object for gradient accumulation.
optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation.
dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
your dataloader object, would be called like iter(dataloader)
accumulate_size (int): the number of steps to accumulate gradients
gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
list of gradient handler objects. Default is None.
lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
your ``lr_scheduler`` object for gradient accumulation. Defaults to None.
More details about `gradient_handlers` could be found in
`Gradient_handler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/engine/gradient_handler>`_.
More details about `lr_scheduler` could be found
`lr_scheduler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/lr_scheduler>`_. and
`how to adjust learning rate <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
"""
optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)
if gradient_handlers is not None:
gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers]
if lr_scheduler is not None:
lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
return optimizer, dataloader, gradient_handlers, lr_scheduler

View File

@@ -0,0 +1,197 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from torch import Tensor
from typing import Iterable, Any
from colossalai.nn.optimizer import ColossalaiOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from colossalai.utils import conditional_context
from colossalai.engine import BaseGradientHandler
class GradAccumOptimizer(ColossalaiOptimizer):
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
before accumulation size is reached.
Args:
optim (:class:`torch.optim.Optimizer`): Your optimizer object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
model (:class:`torch.nn.Module`):
Your model object to check if it is DistributedDataParallel for special handling of no_sync() context.
"""
def __init__(self, optim: Optimizer, accumulate_size: int, model: nn.Module = None):
super().__init__(optim)
self.accumulate_size = accumulate_size
self.accumulate_step = 0
# handle pytorch ddp auto all reduce
self.model = model
self.is_torch_ddp = isinstance(self.model, DistributedDataParallel)
def zero_grad(self, *args, **kwargs):
if self.accumulate_step == 0:
self.optim.zero_grad(*args, **kwargs)
def step(self, *args, **kwargs):
if self.accumulate_step < self.accumulate_size:
return None
else:
self.accumulate_step = 0
return self.optim.step(*args, **kwargs)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.accumulate_step < self.accumulate_size:
pass
else:
self.optim.clip_grad_norm(model, max_norm)
def backward(self, loss: Tensor):
self.accumulate_step += 1
if self.is_torch_ddp:
no_sync = self.accumulate_step < self.accumulate_size
with conditional_context(self.model.no_sync(), enable=no_sync):
scaled_loss = loss / self.accumulate_size
self.optim.backward(scaled_loss)
else:
scaled_loss = loss / self.accumulate_size
self.optim.backward(scaled_loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
self.accumulate_step += 1
no_sync = self.is_torch_ddp and self.accumulate_step < self.accumulate_size
if no_sync:
with self.model.no_sync():
self.optim.backward_by_grad(tensor, grad)
else:
self.optim.backward_by_grad(tensor, grad)
class GradAccumDataloader:
"""A wrapper for dataloader to enable gradient accumulation by dropping the last incomplete steps.
Note:
The dataloader would drop the last incomplete steps for gradient accumulation.
For example, if a dataloader has 10 batches of data and accumulate size is 4. The model parameters will
be updated only twice at step 4 and step 8. The last two batches of data do not form a complete 4-step cycle.
Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader,
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
Args:
optim (``Iterable``): Your dataloader object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
def __init__(self, dataloader: Iterable, accumulate_size: int) -> None:
self.dataloader = dataloader
self.consume_remain_data = not isinstance(dataloader, DataLoader)
self.steps_per_epoch = len(dataloader) - len(dataloader) % accumulate_size
def __getattr__(self, __name: str) -> Any:
return getattr(self.dataloader, __name)
def __len__(self):
return self.steps_per_epoch
def __iter__(self):
self._cur_step = 0
self._dataiter = iter(self.dataloader)
return self
def __next__(self) -> Any:
if self._cur_step < self.steps_per_epoch:
self._cur_step += 1
if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
# this is to handle non standard pytorch dataloader
# such as dali dataloader
while True:
try:
_ = next(self._dataiter)
except StopIteration:
break
return next(self._dataiter)
else:
raise StopIteration
class GradAccumLrSchedulerByStep(_LRScheduler):
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
before accumulation size is reached.
Args:
lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`):
Your ``lr_scheduler`` object for gradient accumulation.
accumulate_size (int): The number of steps to accumulate gradients.
"""
def __init__(self, lr_scheduler: _LRScheduler, accumulate_size: int) -> None:
self.lr_scheduler = lr_scheduler
self.accumulate_size = accumulate_size
self.accumulate_step = 0
@staticmethod
def compute_effective_steps_per_epoch(dataloader: Iterable, accumulate_size: int):
return len(dataloader) // accumulate_size
def __getattr__(self, __name: str) -> Any:
return getattr(self.lr_scheduler, __name)
def step(self, *args, **kwargs):
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
else:
self.accumulate_step = 0
self.lr_scheduler.step(*args, **kwargs)
def get_lr(self):
return self.lr_scheduler.get_lr()
def get_last_lr(self):
return self.lr_scheduler.get_last_lr()
def print_lr(self, *args, **kwargs):
self.lr_scheduler.print_lr(*args, **kwargs)
def state_dict(self) -> dict:
return self.lr_scheduler.state_dict()
def load_state_dict(self, state_dict: dict) -> None:
self.lr_scheduler.load_state_dict(state_dict)
class GradAccumGradientHandler:
r"""A wrapper for the gradient handler to enable gradient accumulation by skipping the steps
before accumulation size is reached.
Args:
grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`.
accumulate_size (int): The number of steps to accumulate gradients.
More details about ``gradient_handlers`` could be found in
`Gradient_handler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/engine/gradient_handler>`_.
"""
def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None:
assert isinstance(grad_handler, BaseGradientHandler), \
f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}'
self.grad_handler = grad_handler
self.accumulate_size = accumulate_size
self.accumulate_step = 0
def handle_gradient(self):
self.accumulate_step += 1
if self.accumulate_step < self.accumulate_size:
pass
else:
self.accumulate_step = 0
self.grad_handler.handle_gradient()

View File

@@ -12,6 +12,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from ._base_schedule import BaseSchedule
@@ -115,7 +116,7 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine):
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
self.dtype = torch.half
model = model.model
sig = inspect.signature(model.forward)
@@ -386,8 +387,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
self.num_model_chunks = num_model_chunks
def pre_processing(self, engine):
# FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
if hasattr(engine.model, 'colo_attr'):
if isinstance(engine.model, ShardedModelV2):
self.dtype = torch.half
elif isinstance(engine.model[0], NaiveAMPModel):
self.dtype = torch.half