mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 18:09:06 +00:00
Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b7699.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
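
As a quick orientation for the diff below, here is a minimal sketch of how the updated engine API introduced by this commit is meant to be driven. It assumes `model`, `optimizer`, `criterion` and `train_dataloader` already exist; only the keyword arguments visible in the new `Engine.__init__` and `Engine.step` signatures of this diff are used, and the `NoPipelineSchedule` import path is taken from the old `colossalai/engine/__init__.py` line removed below.

from colossalai.engine import Engine
from colossalai.engine.amp import AMP_TYPE
from colossalai.engine.schedule import NoPipelineSchedule

# AMP_TYPE.TORCH can now be combined with tensor parallelism thanks to the new GradScaler.
schedule = NoPipelineSchedule(amp_type=AMP_TYPE.TORCH)
engine = Engine(model=model,
                optimizer=optimizer,
                criterion=criterion,
                step_schedule=schedule,
                gradient_accumulation=4,   # one engine.step() consumes 4 batches
                gradient_clipping=1.0)     # grad-norm clipping applied in optimizer_step

engine.train()
data_iter = iter(train_dataloader)
steps_per_epoch = len(train_dataloader) // engine.gradient_accumulation
for it in range(steps_per_epoch):
    output, label, loss = engine.step(data_iter,
                                      is_last_iteration=(it == steps_per_epoch - 1))
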
@@ -1,7 +1,7 @@
from .amp_type import AMP_TYPE
from ._base_engine import Engine
from .gradient_handler import *
from .schedule import *
from .amp import *

__all__ = ['Engine']
@@ -1,7 +1,9 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.builder import build_gradient_handler
from colossalai.context import ParallelMode
@@ -9,89 +11,103 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from .schedule import BaseSchedule, NoPipelineSchedule
from .schedule import BaseSchedule

class Engine:
    """Basic engine class for training and evaluation. It runs a specific process method
    :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
    It controls an iteration in training.

    :param train_dataloader: Dataloader in training
    :param test_dataloader: Dataloader in evaluation
    :param model: The neural network model
    :param criterion: Criterion for calculating loss
    :param optimizer: Optimizer for updating the parameters
    :param lr_scheduler: Learning rate scheduler adjusting the learning rate during training or evaluation
    :param schedule: Running schedule in :meth:`step`
    :type train_dataloader: DataLoader, optional
    :type test_dataloader: DataLoader, optional
    :param step_schedule: Running schedule in :meth:`step`
    :param gradient_accumulation: Steps of gradient accumulation
    :param gradient_clipping: The norm of gradient clipping
    :type model: Module
    :type criterion: _Loss, optional
    :type optimizer: Optimizer, optional
    :type lr_scheduler: _LRScheduler, optional
    :type schedule: BaseSchedule, optional
    :type optimizer: Optimizer
    :type step_schedule: BaseSchedule, optional
    :type gradient_accumulation: int, optional
    :type gradient_clipping: float, optional
    """

    def __init__(self,
                 train_dataloader: Optional[DataLoader] = None,
                 test_dataloader: Optional[DataLoader] = None,
                 model: Module = None,
                 criterion: _Loss = None,
                 optimizer: Optimizer = None,
                 lr_scheduler: Optional[_LRScheduler] = None,
                 schedule: BaseSchedule = None):
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        assert model is not None, "Engine requires a model"
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.schedule = schedule if schedule is not None \
            else NoPipelineSchedule()
                 model: Module,
                 optimizer: Optimizer,
                 criterion: _Loss,
                 step_schedule: BaseSchedule,
                 gradient_handlers: list = None,
                 gradient_accumulation: int = 1,
                 gradient_clipping: float = 0.0,
                 ):
        self._model = model
        self._optimizer = optimizer
        self._criterion = criterion
        self._schedule = step_schedule

        # schedule initialize
        self._schedule.initialize(model, optimizer)

        # state
        self.training = True  # default

        # gradient accumulation
        assert gradient_accumulation > 0, 'gradient accumulation size must be larger than 0'
        self._grad_accum_size = gradient_accumulation
        self._grad_clip = gradient_clipping
        self._logger = get_global_dist_logger()

        # build gradient handler
        self._gradient_handlers = []
        gradient_handler_cfg = []

        if hasattr(gpc.config, 'gradient_handler'):
            assert isinstance(gpc.config.gradient_handler, list), \
        if gradient_handlers is not None:
            assert isinstance(gradient_handlers, list), \
                f'argument gradient_handler_cfg expected type list, ' \
                f'but got type {type(gpc.config.gradient_handler)}'
            gradient_handler_cfg = gpc.config.gradient_handler
        elif isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                         ZeroRedundancyOptimizer_Level_3)):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
                f'but got type {type(gradient_handlers)}'
        elif isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                    ZeroRedundancyOptimizer_Level_3)):
            gradient_handlers = [dict(type='ZeROGradientHandler')]
            self._logger.info(
                "Training with zero is detected, ZeROGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
        elif gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
                ParallelMode.DATA) > 1:
            gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
            gradient_handlers = [dict(type='DataParallelGradientHandler')]
            self._logger.info(
                "Data parallel training is detected, DataParallelGradientHandler is automatically "
                "added even though not specified in the configuration",
                ranks=[0])
        if len(gradient_handler_cfg) == 0:

        if gradient_handlers is None:
            self._logger.warning(
                "No gradient handler is set up, please make sure you do not need "
                "to all-reduce the gradients after a training step.",
                ranks=[0])
        for cfg in gradient_handler_cfg:
            handler = build_gradient_handler(cfg, self.model, self.optimizer)
            self._gradient_handlers.append(handler)
        else:
            for cfg in gradient_handlers:
                handler = build_gradient_handler(cfg, model, optimizer)
                self._gradient_handlers.append(handler)

        self.schedule.initialize(self.train_dataloader, self.model,
                                 self.criterion, self.optimizer,
                                 self.lr_scheduler)
        self.forward_only = False
    @property
    def model(self):
        return self._model

    @property
    def optimizer(self):
        return self._optimizer

    @property
    def criterion(self):
        return self._criterion

    @property
    def schedule(self):
        return self._schedule

    @property
    def gradient_accumulation(self):
        return self._grad_accum_size

    def handle_gradient(self):
        """Handles all-reduce operations of gradients across different parallel groups.
@@ -99,72 +115,62 @@ class Engine:
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def set_dataloader(self, data: DataLoader, train: bool = True):
        """Sets dataloader in training or evaluation.

        :param data: Dataloader to be set
        :param train: Set training dataloader if True, otherwise evaluation dataloader
        :type data: DataLoader
        :type train: bool
        """
        if train:
            self.train_dataloader = data
        else:
            self.test_dataloader = data

    def get_model(self):
        """Returns the neural network model in the engine.
        """
        return self.model

    def get_optimizer(self):
        """Returns the optimizer in the engine.
        """
        return self.optimizer

    def get_lr_scheduler(self):
        """Returns the learning rate scheduler in the engine.
        """
        return self.lr_scheduler

    def train(self):
        """Sets the model to training mode.
        """
        self.forward_only = False
        self.schedule.train(dataloader=self.train_dataloader, mode=True)
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode.
        """
        self.forward_only = True
        self.schedule.train(dataloader=self.test_dataloader, mode=False)
        self.training = False
        self._model.eval()

    def is_train(self):
        """Returns True if it is in training, otherwise False.
        """
        return not self.forward_only

    def get_lr(self):
        """Gets the current learning rate.
        """
        return self.schedule.get_lr()

    def step(self, return_loss=True):
    def step(self,
             data_iter,
             is_last_iteration: bool = False,
             return_loss=True):
        """A running step based on the schedule. Usually, it runs a training or
        evaluation step over a batch of the dataset.

        :param data_iter: Data iterator of the dataset
        :param is_last_iteration: If True, this iteration is the last iteration in the epoch
        :param return_loss: Loss will be returned if True
        :type return_loss: bool
        :type data_iter: Iterator
        :type is_last_iteration: bool, optional
        :type return_loss: bool, optional
        :return: (output, label, loss)
        """
        self.schedule.zero_grad(forward_only=self.forward_only)
        if self.training:
            self._optimizer.zero_grad()

        output, label, loss = self.schedule.forward_backward_step(
            forward_only=self.forward_only, return_loss=return_loss)
        # differentiate training and eval with grad accum
        if self.training:
            for i in range(self._grad_accum_size):
                output, label, loss = self._schedule.forward_backward_step(
                    data_iter, self._model, self._criterion, self._optimizer,
                    forward_only=False,
                    grad_accum_size=self._grad_accum_size,
                    return_loss=return_loss)

        if not self.forward_only:
            # all reduce gradients
            self.handle_gradient()
                if i == self._grad_accum_size - 1:
                    # all reduce gradients
                    self.handle_gradient()
            self._schedule.optimizer_step(self._model, self._optimizer, self._grad_clip)
        else:
            output, label, loss = self._schedule.forward_backward_step(
                data_iter, self._model, self._criterion, self._optimizer,
                forward_only=True,
                grad_accum_size=1,
                return_loss=return_loss)

        self.schedule.step()
        # consume the remaining dataset left out due to gradient accumulation
        if is_last_iteration:
            while True:
                try:
                    _ = next(data_iter)
                except StopIteration:
                    break

        return output, label, loss

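One consequence of the new step() worth spelling out: a single call now pulls gradient_accumulation batches out of data_iter, so the number of engine steps per epoch shrinks accordingly, and is_last_iteration exists to drain whatever does not divide evenly. A small illustration with invented numbers:

# Invented numbers: 10 batches per epoch, accumulation size 4.
batches_per_epoch = 10
grad_accum_size = 4
full_steps = batches_per_epoch // grad_accum_size            # 2 optimizer steps, consuming 8 batches
leftover = batches_per_epoch - full_steps * grad_accum_size  # 2 batches are left over
# Passing is_last_iteration=True on the final engine.step() drains those leftover
# batches from data_iter, so the next epoch starts from a fresh iterator.
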
colossalai/engine/amp/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .grad_scaler import GradScaler
from .amp_type import AMP_TYPE
colossalai/engine/amp/grad_scaler.py (new file, 577 lines)
@@ -0,0 +1,577 @@
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
import torch
from collections import defaultdict, abc
import warnings
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from colossalai.context import ParallelMode
import torch.distributed as dist
from colossalai.core import global_context as gpc

class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda or master_tensor.device.type == 'xla'
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(
                device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
#   causes a circular reference, which we'd rather avoid.
class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_per_optimizer_state():
    return {"stage": OptState.READY, "found_inf_per_device": {}}

class GradScaler(object):
    _scale: Optional[torch.Tensor]
    _grows_tracker: Optional[torch.Tensor]
    _per_optimizer_states: Dict[int, Dict[str, Any]]
    """
    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
    conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
    and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
    a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
    the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
    without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
    ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
      themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
      If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
      ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
    """

    def __init__(self,
                 init_scale=2.**16,
                 growth_factor=2.0,
                 backoff_factor=0.5,
                 growth_interval=2000,
                 enabled=True):
        if enabled and not torch.cuda.is_available():
            warnings.warn(
                "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
            self._enabled = False
        else:
            self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(
                _refresh_per_optimizer_state)

    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, "Attempted {} but _scale is None. ".format(
            funcname) + fix
        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
            funcname) + fix
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full(
            (1,), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full(
            (1,), self._init_growth_tracker, dtype=torch.int32, device=dev)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda or outputs.device.type == 'xla'
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        # holds a reference that can be overwritten by apply_scale
        stash: List[_MultiDeviceReplicator] = []

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda or val.device.type == 'xla'
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, list) or isinstance(val, tuple):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError(
                    "outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(
            lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError(
                            "Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
                        to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
                                                                     per_device_found_inf.get(
                                                                         device),
                                                                     per_device_inv_scale.get(device))
        # For tensor parallel parameters it should be all-reduced over the tensor parallel process group
        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for tensor in per_device_found_inf._per_device_tensors.values():
                dist.all_reduce(tensor, op=dist.ReduceOp.MAX,
                                group=gpc.get_group(ParallelMode.TENSOR))
        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update().")
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, False)
        optimizer_state["stage"] = OptState.UNSCALED

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
            retval = optimizer.step(*args, **kwargs)
        return retval

    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
            earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if (not self._enabled):
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError(
                "Closure use is not currently supported if GradScaler is enabled.")

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update().")

        retval = None

        if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc
            retval = optimizer.step(*args, **dict(kwargs, grad_scaler=self))
            optimizer_state["stage"] = OptState.STEPPED
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert len(optimizer_state["found_inf_per_device"]
                   ) > 0, "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(
            optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly, it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                # type: ignore[attr-defined]
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
                          for state in self._per_optimizer_states.values()
                          for found_inf in state["found_inf_per_device"].values()]

            assert len(
                found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            torch._amp_update_scale_(_scale,
                                     _growth_tracker,
                                     found_inf_combined,
                                     self._growth_factor,
                                     self._backoff_factor,
                                     self._growth_interval)

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return self._init_scale if self._scale is None else self._get_scale_async().item()
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item()
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update`.
        """
        return {"scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
           state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError("The source state dict is empty, possibly because it was saved "
                               "from a disabled instance of GradScaler.")

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])

    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
                                                         "of an iteration, or at the end after scaler.update()."
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state['_init_scale'] = self.get_scale()
            state['_init_growth_tracker'] = self._get_growth_tracker()
            state['_scale'] = None
            state['_growth_tracker'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full(
            (1,), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
            self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

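The only behavioural difference from the stock torch.cuda.amp.GradScaler is the block in _unscale_grads_ that all-reduces the per-device found_inf flags with MAX over the tensor-parallel group, so every tensor-parallel rank agrees on whether the step should be skipped. A minimal usage sketch, mirroring what NoPipelineSchedule.optimizer_step does further below (model, optimizer and loss are assumed to exist):

from colossalai.engine.amp import GradScaler
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32

scaler = GradScaler()

# backward on the scaled loss
scaler.scale(loss).backward()

# unscale first so that clipping sees the true gradient magnitudes
scaler.unscale_(optimizer)
clip_grad_norm_fp32(model.parameters(), 1.0)

# step() is skipped on every rank if any tensor-parallel rank saw inf/NaN gradients
scaler.step(optimizer)
scaler.update()
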
@@ -5,125 +5,85 @@ from abc import ABC, abstractmethod

import torch

from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_current_device


class BaseSchedule(ABC):
    """A basic helper class to control the process of training or evaluation.
    It mainly consists of forward_backward_step for the gradient backward pass and
    optimizer_step for the parameter update.
    For convenience when enabling FP16, we aggregate all code that controls FP16
    in the schedule class.
    """

    def __init__(self):
        self.initialized = False
        self.logger = get_global_dist_logger()

    @property
    @abstractmethod
    def num_steps(self):
        """The number of batches in training or evaluation.
        """
        pass

    def initialize(self,
                   dataloader=None,
                   model=None,
                   criterion=None,
                   optimizer=None,
                   lr_scheduler=None):
        """Initializes the schedule and sets parameters before running.

        :param dataloader: DataLoader in training or evaluation
        :param model: The neural network model
        :param criterion: Criterion for calculating loss
        :param optimizer: Optimizer for updating the parameters
        :param lr_scheduler: Learning rate scheduler in the process
        """
        self.dataloader = dataloader
        assert model is not None, "Schedule requires a model"
        self.model = model
        assert criterion is not None, "Schedule requires a criterion"
        self.criterion = criterion
        assert optimizer is not None, "Schedule requires an optimizer"
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.initialized = True

    def check_initialized(self):
        """Checks whether the schedule is initialized.
        """
        assert self.initialized, \
            'Schedule is not initialized. Call schedule.initialize(...) before using it.'

    def load_batch(self):
        """Loads a batch of data. It returns the data and labels which are
        already on the same GPU as the model.

        :return: (data, label)
        :rtype: (Tensor, Tensor)
        """
        self.check_initialized()
        if self.data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        data, label = next(self.data_iter)
        return self._move_to_device(data), self._move_to_device(label)
    @staticmethod
    def _move_tensor(element):
        if torch.is_tensor(element):
            if not element.is_cuda:
                return element.to(get_current_device()).detach()
        return element

    def _move_to_device(self, data):
        if isinstance(data, (
                tuple,
                list,
        )):
            data = tuple([
                d.to(get_current_device()).detach() for d in data
                if torch.is_tensor(d)
            ])
        if isinstance(data, (tuple, list)):
            data = tuple([self._move_tensor(d) for d in data])
        elif torch.is_tensor(data):
            data = data.to(get_current_device()).detach()
        return data

    def train(self, dataloader=None, mode=True):
        """Sets the dataloader to be used and turns the model to
        training or evaluation mode.
    def load_batch(self, data_iter):
        """Loads a batch from the data iterator. It returns the data and labels which are
        already on the same GPU as the model.

        :param dataloader: Dataloader to be used
        :param mode: If True, the model will be set to training mode. Otherwise, evaluation mode.
        :return: (data, label)
        :rtype: (Tensor, Tensor)
        """
        self.check_initialized()
        if mode:
            self.model.train()
        else:
            self.model.eval()
        if dataloader is not None:
            self.dataloader = dataloader
            self.data_iter = iter(dataloader)
        if data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        data, label = next(data_iter)
        return self._move_to_device(data), self._move_to_device(label)

    def zero_grad(self, forward_only=False):
        """Cleans gradients with the optimizer.
        """
        if not forward_only:
            self.check_initialized()
            self.optimizer.zero_grad()
    def initialize(self, model, optimizer):
        """Initializes the model and the optimizer before training.
        This is often used in FP16 training.

    def get_lr(self):
        """Returns the current learning rate.
        :param model: The neural network model
        :param optimizer: Optimizer for updating the parameters
        """
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_lr()[0]
        else:
            return self.optimizer.param_groups[0]['lr']

    def step(self):
        """Updates the parameters and learning rate with the optimizer.
        """
        self.check_initialized()
        self.optimizer.step()
        # update lr scheduler
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return model, optimizer

    @abstractmethod
    def forward_backward_step(self, forward_only=False, return_loss=True):
    def forward_backward_step(self,
                              data_iter,
                              model,
                              criterion,
                              optimizer=None,
                              forward_only=False,
                              grad_accum_size: int = 1,
                              return_loss=True):
        """The process function over a batch of data for training or evaluation.

        :param forward_only: If True, the process won't include backward.
        :param return_loss: If False, the loss won't be returned.
        :param data_iter: Data iterator of the dataset
        :param model: Model used in training or evaluation
        :param optimizer: Optimizer used in training or evaluation
        :param criterion: Loss function
        :param forward_only: If True, the process won't include backward
        :param grad_accum_size: Steps of gradient accumulation
        :param return_loss: If False, the loss won't be returned
        """
        pass

    @abstractmethod
    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
        """Updates the parameters with the optimizer.

        :param model: The neural network model
        :param optimizer: Optimizer for updating the parameters
        :param grad_clipping: The norm of gradient clipping
        :type grad_clipping: float, optional
        """
        pass

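Since the schedule no longer stores the dataloader, model or optimizer at initialize time, a custom schedule only needs to fill in forward_backward_step and optimizer_step (assuming those are the remaining abstract methods after this change). A bare-bones single-device sketch, not taken from the repository:

class NaiveSchedule(BaseSchedule):
    """Plain FP32 schedule: one forward/backward per call, no AMP, no pipelining."""

    def forward_backward_step(self, data_iter, model, criterion, optimizer=None,
                              forward_only=False, grad_accum_size: int = 1, return_loss=True):
        data, label = self.load_batch(data_iter)
        output = model(*data)
        if not isinstance(output, (tuple, list)):
            output = (output,)
        loss = None
        if return_loss:
            # divide so that accumulated micro-steps average out, as the built-in schedules do
            loss = criterion(*output, *label) / grad_accum_size
        if not forward_only:
            loss.backward()
        return output, label, loss

    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
        optimizer.step()
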
@@ -4,19 +4,24 @@
try:
    import apex.amp as apex_amp
except:
    print('apex is required for mixed precision training')
    pass

try:
    import torch.cuda.amp as torch_amp
except:
    print('PyTorch amp is not supported with the current PyTorch version')
    pass

from typing import Iterable

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.amp_type import AMP_TYPE
from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                           ZeroRedundancyOptimizer_Level_3)
from ._utils import convert_to_fp16
from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16, convert_to_fp32
from ..amp import AMP_TYPE, GradScaler


class NoPipelineSchedule(BaseSchedule):
@@ -30,6 +35,7 @@ class NoPipelineSchedule(BaseSchedule):
    :type amp_type: AMP_TYPE
    :type amp_config: dict
    """

    def __init__(
        self,
        amp_type: AMP_TYPE = None,
@@ -41,12 +47,6 @@ class NoPipelineSchedule(BaseSchedule):
        assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
            'unrecognised value for argument fp16, it can only be None, torch or apex'

        # LSG: check compatibility
        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
                ParallelMode.TENSOR) > 1:
            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
                'You can only use AMP_TYPE.PARALLEL for tensor parallel training'
        self.use_zero_level_2_3 = False

        if amp_type is not None:
@@ -79,107 +79,110 @@ class NoPipelineSchedule(BaseSchedule):
            self.fp16 = False
            self.amp_type = None

    @property
    def num_steps(self):
        return len(self.dataloader)

    def initialize(self,
                   dataloader,
                   model,
                   criterion,
                   optimizer,
                   lr_scheduler=None):
        super().initialize(dataloader,
                           model,
                           criterion,
                           optimizer,
                           lr_scheduler=lr_scheduler)
        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                       ZeroRedundancyOptimizer_Level_3)):
    def initialize(self, model: nn.Module, optimizer: Optimizer):
        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
                                  ZeroRedundancyOptimizer_Level_3)):
            self.use_zero_level_2_3 = True
            assert self.amp_type != AMP_TYPE.PARALLEL, 'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'
            assert self.amp_type != AMP_TYPE.PARALLEL, \
                'ZeRO Level 2 and 3 are mutually exclusive with AMP_TYPE.PARALLEL'

        if self.fp16:
            if self.amp_type == AMP_TYPE.TORCH:
                self._torch_amp_scaler = torch_amp.GradScaler(**self.amp_cfg)
                self._torch_amp_scaler = GradScaler(**self.amp_cfg)
            elif self.amp_type == AMP_TYPE.APEX:
                self.model, self.optimizer = apex_amp.initialize(
                    self.model, self.optimizer, **self.amp_cfg)
                model, optimizer = apex_amp.initialize(model, optimizer, **self.amp_cfg)

    def forward_backward_step(self, forward_only=False, return_loss=True):
        return model, optimizer

    def forward_backward_step(self,
                              data_iter: Iterable,
                              model: nn.Module,
                              criterion: nn.modules.loss._Loss,
                              optimizer: Optimizer = None,
                              forward_only: bool = False,
                              grad_accum_size: int = 1,
                              return_loss: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
|
||||
The returned labels and loss will None if :attr:`return_loss` is False.
|
||||
|
||||
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
|
||||
:param model: Model for training and inference
|
||||
:param criterion: Loss function for training
|
||||
:param optimizer: Optimizer used for training
|
||||
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
|
||||
:param grad_accum_size: The number of iterations for gradient accumulation
|
||||
:param return_loss: Loss will be returned if True
|
||||
:type data_iter: Iterator
|
||||
:type model: torch.nn.Module
|
||||
:type criterion: torch.nn.modules.loss._Loss
|
||||
:type optimizer: torch.optim.Optimizer
|
||||
:type forward_only: bool, optional
|
||||
:type grad_accum_size: int
|
||||
:type return_loss: bool, optional
|
||||
:return: (output, label, loss)
|
||||
"""
|
||||
assert forward_only or return_loss, \
|
||||
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
|
||||
|
||||
        data, label = self.load_batch()
        data, label = self.load_batch(data_iter)
        loss = None

        # LSG: leave for debug, make sure dataloader is deterministic
        # if forward_only:
        #     img = data[0]
        #     rank = gpc.get_local_rank(ParallelMode.DATA)
        #     world_size = gpc.get_world_size(ParallelMode.DATA)
        #     group = gpc.get_group(ParallelMode.DATA)
        #     input_list = [img.clone() for _ in range(world_size)]
        #     output_list = [torch.empty_like(img) for _ in range(world_size)]
        #     output_list[rank] = img.clone()
        #     dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
        #     assert torch.equal(output_list[0], output_list[1]) # and torch.equal(output_list[1], output_list[2])

        # forward
        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
            with torch_amp.autocast():
                output = self.model(*data)
                output = model(*data)
                if not isinstance(output, (tuple, list)):
                    output = (output,)
                if return_loss:
                    loss = self.criterion(*output, *label)
                    loss = criterion(*output, *label)
        else:
            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
                data = convert_to_fp16(data)

            output = self.model(*data)
            output = model(*data)

            if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
                output = convert_to_fp32(output)

            if not isinstance(output, (tuple, list)):
                output = (output,)
            if return_loss:
                loss = self.criterion(*output, *label)
                loss = criterion(*output, *label)

        loss /= grad_accum_size

        if not forward_only:
            # backward
            if self.use_zero_level_2_3:
                self.optimizer.backward(loss)
                optimizer.backward(loss)
            elif self.fp16:
                if self.amp_type == AMP_TYPE.APEX:
                    with apex_amp.scale_loss(loss,
                                             self.optimizer) as scaled_loss:
                    with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                elif self.amp_type == AMP_TYPE.TORCH:
                    self._torch_amp_scaler.scale(loss).backward()
                elif self.amp_type == AMP_TYPE.PARALLEL:
                    loss = self.optimizer.scale_loss(loss)
                    loss = optimizer.scale_loss(loss)
                    loss.backward()
                    # scale back to display the original value in logs
                    loss.div_(self.optimizer.grad_scaler.scale)
                    loss.div_(optimizer.grad_scaler.scale)
            else:
                loss.backward()

        if return_loss:
            return output, label, loss
            return output, label, loss * grad_accum_size
        else:
            return output, None, None

    def step(self):
    def optimizer_step(self, model: nn.Module, optimizer: Optimizer, grad_clipping: float = 0.0):
        # step optimizer
        if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
            self._torch_amp_scaler.step(self.optimizer)
            if grad_clipping > 0.0:
                self._torch_amp_scaler.unscale_(optimizer)
                clip_grad_norm_fp32(model.parameters(), grad_clipping)
            self._torch_amp_scaler.step(optimizer)
            self._torch_amp_scaler.update()
        else:
            self.optimizer.step()

            # update lr scheduler
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            if not self.fp16 and not self.use_zero_level_2_3 and grad_clipping > 0.0:
                clip_grad_norm_fp32(model.parameters(), grad_clipping)
            optimizer.step()

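A note on the loss bookkeeping above: the loss is divided by grad_accum_size before backward so that the gradients contributed by each accumulated micro-step average out, and the value returned to the engine is multiplied back so that logs still show the per-batch loss. With invented numbers:

# Invented numbers: criterion returns 2.0 on this batch, grad_accum_size is 4.
raw_loss = 2.0
grad_accum_size = 4
backward_loss = raw_loss / grad_accum_size        # 0.5 flows into backward()
reported_loss = backward_loss * grad_accum_size   # 2.0 is what forward_backward_step returns
# Summing four such 0.5-scaled gradient contributions before optimizer_step is
# equivalent to averaging the four raw losses and doing a single backward pass.
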
@@ -15,7 +15,7 @@ from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
from colossalai.utils import get_current_device
from ._base_schedule import BaseSchedule
from ._utils import convert_to_fp16
from ..amp_type import AMP_TYPE
from ..amp import AMP_TYPE


def squeeze(x: Union[Tensor, tuple, list]):
@@ -93,12 +93,11 @@ class PipelineSchedule(BaseSchedule):
        )

    # Pipeline schedule just puts data in memory
    def load_batch(self):
        self.check_initialized()
        if self.data_iter is None:
    def load_batch(self, data_iter):
        if data_iter is None:
            raise RuntimeError('Dataloader is not defined.')
        self.batch_pos = 0
        data, label = next(self.data_iter)
        data, label = next(data_iter)
        self.batch_data, self.batch_label = \
            self._move_to_device(data), self._move_to_device(label)
        batch_size = self.batch_data.shape[0]
@@ -117,23 +116,8 @@ class PipelineSchedule(BaseSchedule):
        self.batch_pos += self.microbatch_size
        return (data,), (label,)

    @property
    def num_steps(self):
        return len(self.dataloader)

    def initialize(self,
                   dataloader,
                   model,
                   criterion,
                   optimizer,
                   lr_scheduler=None):
        super().initialize(dataloader,
                           model,
                           criterion,
                           optimizer,
                           lr_scheduler=lr_scheduler)
        if isinstance(self.optimizer, (ZeroRedundancyOptimizer_Level_2,
                                       ZeroRedundancyOptimizer_Level_3)):
    def initialize(self, model, optimizer):
        if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
            raise TypeError(
                "Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
            )
@@ -145,7 +129,8 @@ class PipelineSchedule(BaseSchedule):
                'default tensor dtype is set to torch.half for fp16 training',
                ranks=[0])

    def forward_step(self, input_tensor, return_tensors, return_loss=True):
    def forward_step(self, model, criterion, input_tensor, return_tensors,
                     grad_accum_size, return_loss=True):
        """Forward step for passed-in model. If it is the first stage, the input tensor
        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
        Returns output tensor. This is a helper function and can be ignored by users.
@@ -156,14 +141,14 @@ class PipelineSchedule(BaseSchedule):
        if self.amp_type == AMP_TYPE.PARALLEL:
            input_tensor = convert_to_fp16(input_tensor)
        input_tensor = squeeze(input_tensor)
        output_tensor = self.model(input_tensor)
        output_tensor = model(input_tensor)
        output_tensor = squeeze(output_tensor)

        if gpc.is_last_rank(ParallelMode.PIPELINE):
            if return_loss:
                input_tensor, label = self.load_micro_batch()
                loss_reduced = self.criterion(output_tensor, *
                                              label) / self.num_microbatches
                loss_reduced = criterion(output_tensor, *label) \
                    / (self.num_microbatches * grad_accum_size)
                return_tensors.append(
                    tuple((output_tensor, label[0], loss_reduced)))
                return loss_reduced
@@ -174,7 +159,7 @@ class PipelineSchedule(BaseSchedule):
        else:
            return output_tensor

    def backward_step(self, input_tensor, output_tensor, output_tensor_grad):
    def backward_step(self, optimizer, input_tensor, output_tensor, output_tensor_grad):
        """Backward step through the passed-in output tensor. If it is the last stage, the
        output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
        Returns the gradients with respect to the input tensor (None if first stage).
@@ -187,7 +172,7 @@ class PipelineSchedule(BaseSchedule):

        # Backward pass.
        if output_tensor_grad is None and self.amp_type == AMP_TYPE.PARALLEL:
            output_tensor = self.optimizer.scale_loss(output_tensor)
            output_tensor = optimizer.scale_loss(output_tensor)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

        # Collect the grad of the input_tensor.
@@ -197,17 +182,24 @@ class PipelineSchedule(BaseSchedule):

        return input_tensor_grad

    def forward_backward_step(self, forward_only=True, return_loss=True):
    def forward_backward_step(self,
                              data_iter,
                              model,
                              criterion,
                              optimizer=None,
                              forward_only=False,
                              grad_accum_size: int = 1,
                              return_loss=True):
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses if the last stage, an empty tuple otherwise.

        :return: (output, label, loss)
        """

        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

        self.load_batch()
        self.load_batch(data_iter)
        num_warmup_microbatches = \
            (gpc.get_world_size(ParallelMode.PIPELINE) -
             gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
@@ -233,9 +225,11 @@ class PipelineSchedule(BaseSchedule):
        if not gpc.is_first_rank(ParallelMode.PIPELINE):
            ft_shape = recv_tensor_meta(ft_shape)
        input_tensor = recv_forward(ft_shape)
        output_tensor = self.forward_step(input_tensor,
                                          return_tensors,
                                          return_loss=return_loss)
        output_tensor = self.forward_step(
            model, criterion,
            input_tensor, return_tensors,
            grad_accum_size, return_loss=return_loss
        )
        if not gpc.is_last_rank(ParallelMode.PIPELINE):
            bt_shape = output_tensor.shape
            fs_checker = send_tensor_meta(output_tensor, fs_checker)
@@ -257,9 +251,11 @@ class PipelineSchedule(BaseSchedule):
        for i in range(num_microbatches_remaining):
            last_iteration = (i == (num_microbatches_remaining - 1))

            output_tensor = self.forward_step(input_tensor,
                                              return_tensors,
                                              return_loss=return_loss)
            output_tensor = self.forward_step(
                model, criterion,
                input_tensor, return_tensors,
                grad_accum_size, return_loss=return_loss
            )
            if forward_only:
                send_forward(output_tensor)

@@ -279,9 +275,11 @@ class PipelineSchedule(BaseSchedule):
                input_tensor = input_tensors.pop(0)
                output_tensor = output_tensors.pop(0)

                input_tensor_grad = self.backward_step(input_tensor,
                                                       output_tensor,
                                                       output_tensor_grad)
                input_tensor_grad = self.backward_step(
                    optimizer,
                    input_tensor, output_tensor,
                    output_tensor_grad
                )

                if last_iteration:
                    input_tensor = None
@@ -298,9 +296,11 @@ class PipelineSchedule(BaseSchedule):

            output_tensor_grad = recv_backward(bt_shape)

            input_tensor_grad = self.backward_step(input_tensor,
                                                   output_tensor,
                                                   output_tensor_grad)
            input_tensor_grad = self.backward_step(
                optimizer,
                input_tensor, output_tensor,
                output_tensor_grad
            )

            send_backward(input_tensor_grad)

@@ -309,8 +309,11 @@ class PipelineSchedule(BaseSchedule):
            output, label, loss = tuple(map(list, zip(*return_tensors)))
            return (torch.cat(output, dim=0),
                    torch.cat(label, dim=0),
                    sum(loss))
                    sum(loss) * grad_accum_size)
        else:
            return tuple((torch.cat(return_tensors, dim=0), None, None))
        else:
            return tuple((None, None, None))

    def optimizer_step(self, model, optimizer, grad_clipping: float = 0.0):
        optimizer.step()

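The warmup count at the start of forward_backward_step follows the usual non-interleaved 1F1B rule: each stage runs (pipeline_size - rank - 1) forward-only microbatches before entering the steady one-forward-one-backward phase (implementations typically also clamp this to the total number of microbatches, which is not visible in this hunk). A tiny worked example with an invented pipeline size:

# Invented size: 4 pipeline stages.
pipeline_size = 4
for rank in range(pipeline_size):
    num_warmup_microbatches = pipeline_size - rank - 1
    print(f"stage {rank}: {num_warmup_microbatches} warmup forward passes")
# stage 0: 3 warmup forward passes ... stage 3: 0 warmup forward passes
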
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
    else:
        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
    return ret


def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
    if isinstance(data, Tensor):
        ret = data.float()
    elif isinstance(data, (list, tuple)):
        ret = [val.float() for val in data]
    else:
        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
    return ret

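These two helpers are what NoPipelineSchedule uses to bracket the model call when AMP_TYPE.PARALLEL or ZeRO level 2/3 keeps parameters in half precision: inputs are cast down before the forward pass and outputs are cast back up before the loss. A tiny round-trip sketch (assuming convert_to_fp16 simply calls .half() on tensors, symmetrically to convert_to_fp32 above):

import torch

x = torch.randn(2, 3)          # fp32 batch as it comes from the dataloader
x_half = convert_to_fp16(x)    # torch.float16 input for the half-precision model
y = convert_to_fp32(x_half)    # back to torch.float32 for the criterion
assert y.dtype == torch.float32
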