[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
colossalai/legacy/amp/naive_amp/__init__.py (new file, 60 lines added)
@@ -0,0 +1,60 @@
import inspect

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler
from colossalai.legacy.utils import is_no_pp_or_last_stage

from ._fp16_optimizer import FP16Optimizer
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer


def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain the parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default: 0).
            Note that clipping is ignored if clip_grad_norm == 0.
        dynamic_grad_scale (bool): whether to use the dynamic grad scaler.
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    if use_dynamic_grad_scaler:
        scaler_class = DynamicGradScaler
    else:
        scaler_class = ConstantGradScaler

    sig = inspect.signature(scaler_class.__init__)
    kwargs = dict()
    for param in sig.parameters.values():
        if param.name in amp_config:
            kwargs[param.name] = amp_config.pop(param.name)
    grad_scaler = scaler_class(**kwargs)
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
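
For readers skimming the diff, here is a minimal usage sketch of `convert_to_naive_amp`. It is not part of this commit: the model, optimizer, and `amp_config` values are illustrative, and it assumes the legacy parallel context (`gpc`) has already been set up via `colossalai.legacy.initialize`, since `is_no_pp_or_last_stage()` queries it.

# Illustrative only: wrap a toy model and optimizer with naive AMP.
import torch.nn as nn
from torch.optim import Adam

from colossalai.legacy.amp.naive_amp import convert_to_naive_amp

model = nn.Linear(16, 16).cuda()                 # any torch.nn.Module
optimizer = Adam(model.parameters(), lr=1e-3)

amp_config = dict(
    verbose=False,            # print debug info
    clip_grad_norm=1.0,       # 0 disables clipping
    dynamic_grad_scale=True,  # DynamicGradScaler instead of ConstantGradScaler
)

# Keys matching the chosen grad scaler's signature are consumed by it;
# the remaining keys are forwarded to NaiveAMPOptimizer.
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
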
colossalai/legacy/amp/naive_amp/_fp16_optimizer.py (new file, 372 lines added)
@@ -0,0 +1,372 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
from colossalai.logging import get_dist_logger
from colossalai.utils import multi_tensor_applier

from ._utils import has_inf_or_nan, zero_gard_by_list

try:
    from colossalai._C import fused_optim
except:
    fused_optim = None

__all__ = ['FP16Optimizer']


def load_fused_optim():
    global fused_optim

    if fused_optim is None:
        fused_optim = FusedOptimBuilder().load()


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """
    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)

    Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16.
    """
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        global fused_optim
        load_fused_optim()
        multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)


class FP16Optimizer(Optimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        grad_scaler (BaseGradScaler): gradient scaler, chosen from ``constant_grad_scaler``
            or ``dynamic_grad_scaler``.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
            Note that clipping is ignored if clip_grad_norm == 0.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 grad_scaler: BaseGradScaler,
                 verbose: bool = False,
                 clip_grad_norm=0,
                 dp_process_group: ProcessGroup = None,
                 mp_process_group: ProcessGroup = None):
        # keep defaults for compatibility with pytorch optim
        self._optimizer = optimizer
        self._defaults = optimizer.defaults

        # fp16-related params
        assert isinstance(grad_scaler, BaseGradScaler)
        self._grad_scaler = grad_scaler
        self._found_overflow = torch.cuda.FloatTensor([0.0])
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # misc params
        self._clip_grad_max_norm = clip_grad_norm

        # get process group
        def _get_process_group(parallel_mode):
            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
                return gpc.get_group(parallel_mode)
            else:
                return None

        if dp_process_group is None:
            dp_process_group = _get_process_group(ParallelMode.DATA)
        if mp_process_group is None:
            mp_process_group = _get_process_group(ParallelMode.MODEL)

        self._dp_process_group = dp_process_group
        self._mp_process_group = mp_process_group

        # we maintain three groups of parameters
        # so that the model can have a mixture
        # of fp16 and fp32 params
        # fp16_param_groups: the fp16 params of the model
        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
        # fp32_param_groups: the fp32 params of the model
        # NOTE:
        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
        self._fp16_param_groups = []
        self._fp32_master_param_groups = []
        self._fp32_param_groups = []

        # For all the groups in the original optimizer:
        for param_group in self._optimizer.param_groups:
            fp16_params = []
            fp32_master_params = []
            fp32_params = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor']:
                        fp16_params.append(param)

                        # Create a fp32 copy
                        fp32_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        copy_tensor_parallel_attributes(param, fp32_param)

                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = fp32_param
                        fp32_master_params.append(fp32_param)

                        # Reset existing state dict key to the new main param.
                        if param in self._optimizer.state:
                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params.append(param)
                    else:
                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
                                        f'or torch.cuda.HalfTensor, but got {param.type()}')

            self._fp16_param_groups.append(fp16_params)
            self._fp32_master_param_groups.append(fp32_master_params)
            self._fp32_param_groups.append(fp32_params)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self._optimizer.load_state_dict(self._optimizer.state_dict())

        # log config
        self._logger = get_dist_logger()
        if verbose:
            self._logger.info(
                f"\n========= FP16 Optimizer Config =========\n"
                f"Optimizer: {optimizer.__class__.__name__}\n"
                f"clip_grad_norm = {clip_grad_norm}\n"
                f"grad_scaler = {self._grad_scaler.__class__.__name__}\n"
                f"==========================================",
                ranks=[0])

    @property
    def max_norm(self):
        """Returns the maximum norm of gradient clipping.
        """
        return self._clip_grad_max_norm

    @property
    def grad_scaler(self):
        """Returns the gradient scaler.

        Returns:
            :class:`BaseGradScaler`: gradient scaler.
        """
        return self._grad_scaler

    @property
    def loss_scale(self):
        """Returns the loss scale.

        Returns:
            int: loss scale.
        """
        return self._grad_scaler.scale

    @property
    def optimizer(self):
        """Returns the optimizer.

        Returns:
            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
        """
        return self._optimizer

    @property
    def defaults(self):
        """Returns the default arguments of optimizer.

        Returns:
            dict: optimizer arguments saved in defaults of the optimizer wrapped.
        """
        return self._defaults

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group in self._optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None and has_inf_or_nan(p.grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across dp group
        if self._dp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)

        # all-reduce over model parallel group
        if self._mp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)

        return self._found_overflow.item() > 0

    def zero_grad(self, set_to_none=True):
        """Set gradients to zero.

        Args:
            set_to_none (bool): Whether to set the gradients to None.
        """

        # set_to_none = True can save some memory space
        for param_group in self._optimizer.param_groups:
            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)

    def _get_fp32_param_groups_to_update(self):
        return self._fp32_master_param_groups + self._fp32_param_groups

    def _unscale_grads(self):
        for group in self._get_fp32_param_groups_to_update():
            for p in group:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)

    def _assign_grad_to_fp32_master_param(self):
        # This only needs to be done for the float16 group.
        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
                if fp16_param.grad is not None:
                    fp32_param.grad = fp16_param.grad.float()
                    # clear unneeded grad on fp16 param
                    fp16_param.grad = None

    def _update_fp16_param_from_fp32_param(self):
        fp16_param_data = []
        fp32_master_param_data = []
        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                fp16_param_data.append(fp16_param.data)
                fp32_master_param_data.append(fp32_param.data)
        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
                                        that=fp16_param_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def step(self):
        """Update the model parameters.
        """

        # Copy gradients from model params to main params.
        self._assign_grad_to_fp32_master_param()
        self._unscale_grads()

        overflow = self._check_overflow()
        self._grad_scaler.update(overflow)
        if overflow:
            self.zero_grad()

        # Clip the main gradients.
        grad_norm = None
        if self._clip_grad_max_norm > 0.0:
            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

        if not overflow:
            # Step the optimizer.
            self._optimizer.step()

            # Update params from main params.
            self._update_fp16_param_from_fp32_param()

            # Successful update.
            return True, grad_norm
        else:
            return False, None

    def backward(self, loss):
        """Execute backward pass.

        Args:
            loss (:class:`torch.Tensor`): the loss value.
        """

        scaled_loss = loss * self.grad_scaler.scale
        scaled_loss.backward()

    def state_dict(self):
        """Returns the states of the fp16 optimizer as a dict object.
        """

        state_dict = {}
        state_dict['optimizer'] = self._optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the fp16 optimizer from a dict object.

        Args:
            state_dict (dict): the states of the fp16 optimizer
        """

        # Optimizer.
        self._optimizer.load_state_dict(state_dict['optimizer'])

        # Grad scaler.
        if 'grad_scaler' in state_dict:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

        # Copy data for the main params.
        if 'fp32_master_param_groups' in state_dict:
            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
                                                 state_dict['fp32_master_param_groups']):
                for current_param, ckpt_param in zip(current_group, ckpt_group):
                    current_param.data.copy_(ckpt_param.data)

    def clip_grad_norm(self, clip_grad):
        """Clip gradients by norm.

        Args:
            clip_grad (float): the max norm for clipping
        """
        params = []
        for param_group in self._optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return clip_grad_norm_fp32(params, clip_grad)

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self._optimizer.state

    def _set_state(self, value):
        self._optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self._optimizer.param_groups

    def _set_param_groups(self, value):
        self._optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
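
A hedged sketch of how `FP16Optimizer` is typically driven (in practice through `NaiveAMPOptimizer`, defined below); `model`, `criterion`, and `dataloader` are placeholders, not part of this commit:

# Illustrative training step: scaled backward, overflow check, clip, update.
for data, target in dataloader:               # assumed CUDA tensors
    optimizer.zero_grad()
    output = model(data)                      # fp16 forward through the wrapped model
    loss = criterion(output, target)          # output already cast back to fp32
    optimizer.backward(loss)                  # multiplies loss by the current scale
    success, grad_norm = optimizer.step()     # unscale, overflow check, clip, update
    if not success:
        # overflow: gradients were zeroed and the grad scaler backed off
        continue
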
colossalai/legacy/amp/naive_amp/_utils.py (new file, 49 lines added)
@@ -0,0 +1,49 @@
from typing import List

from torch import Tensor


def has_inf_or_nan(tensor):
    """Check if tensor has inf or nan values.

    Args:
        tensor (:class:`torch.Tensor`): a torch tensor object

    Returns:
        bool: Whether the tensor has inf or nan. True for yes and False for no.
    """
    try:
        # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
        # Pytorch's .sum() creates a one-element tensor of the same type as tensor
        # (which is true for some recent versions of pytorch).
        tensor_sum = float(tensor.float().sum())
        # More efficient version that can be used if .sum() returns a Python scalar
        # tensor_sum = float(tensor.sum())
    except RuntimeError as instance:
        # We want to check if instance is actually an overflow exception.
        # RuntimeError could come from a different error.
        # If so, we still want the exception to propagate.
        if "value cannot be converted" not in instance.args[0]:
            raise
        return True
    else:
        if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
            return True
        return False


def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
    """Clear the gradients of a list of tensors.

    Note: copied from torch.optim.optimizer.
    """
    for param in tensor_list:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
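
A quick, hypothetical illustration of `has_inf_or_nan`; the tensors are made up for the example:

import torch

assert has_inf_or_nan(torch.tensor([1.0, 2.0])) is False           # finite values
assert has_inf_or_nan(torch.tensor([1.0, float('inf')])) is True   # inf detected
assert has_inf_or_nan(torch.tensor([float('nan')])) is True        # nan detected
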
colossalai/legacy/amp/naive_amp/naive_amp.py (new file, 161 lines added)
@@ -0,0 +1,161 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ReduceOp
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc

from ._fp16_optimizer import FP16Optimizer


class NaiveAMPOptimizer(OptimizerWrapper):
    """A wrapper class for an optimizer to cast all parameters to fp16.

    Args:
        optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD.
        grad_scaler (BaseGradScaler): gradient scaler, chosen from ``constant_grad_scaler``
            or ``dynamic_grad_scaler``.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.

    Note:
        clipping is ignored if ``clip_grad_norm`` equals 0.
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
        optim = FP16Optimizer(optim, *args, **kwargs)
        super().__init__(optim)

    def backward(self, loss: Tensor):
        self.optim.backward(loss)

    def step(self):
        return self.optim.step()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if self.optim.max_norm == max_norm:
            return
        raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
                           "If you have supplied clip_grad_norm in the amp_config, "
                           "executing the method clip_grad_norm is not allowed.")


class NaiveAMPModel(nn.Module):
    r"""A wrapper class that casts the wrapped model to fp16 and
    automatically casts its inputs and outputs.

    Args:
        model (torch.nn.Module): torch.nn.Module to be wrapped.
        output_to_fp32 (bool, optional): Whether to cast the output of this module to fp32. (Default: True)
        parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this module.
            (Default: ``ParallelMode.DATA``)
        sync_buffer (bool, optional): whether to synchronize buffers. (Default: True)

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """

    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True,
                 parallel_mode: ParallelMode = ParallelMode.DATA,
                 sync_buffer: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32
        self._sync_buf = sync_buffer

        if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
            self._process_group = gpc.get_group(parallel_mode)
            self._world_size = gpc.get_world_size(parallel_mode)
        else:
            self._process_group = None
            self._world_size = 1
            self._sync_buf = False
        self._first_eval_run = False

    @property
    def sync_buffer(self):
        return self._sync_buf

    @sync_buffer.setter
    def sync_buffer(self, state: bool):
        self._sync_buf = state

    def _convert_to_fp16(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
            input_ = input_.half()
        return input_

    def _convert_to_fp32(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
            input_ = input_.float()
        return input_

    def _reduce_module_buffer(self):
        """
        All-reduce the buffers (e.g. running stats of batch normalization) across
        data parallel ranks so that all the ranks will produce consistent results
        when given the same input.
        """
        buf_list = []

        # find valid buffers
        for buf in self.model.buffers():
            if buf is not None:
                buf_list.append(buf)

        # reduce buffers across data parallel ranks
        if buf_list:
            coalesced_buf = _flatten_dense_tensors(buf_list)
            coalesced_buf.div_(self._world_size)
            dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
            unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
            for old, new in zip(buf_list, unflattened_buf_list):
                old.copy_(new)

    def eval(self):
        self.model.eval()

        # we only sync buffers in the first eval iteration
        # so that future eval iterations can be done without communication
        self._first_eval_run = True

    def forward(self, *args, **kwargs):
        # reducing buffers after forward would raise an error,
        # as we cannot change the variables needed for gradient computation after forward,
        # so we sync buffers before forward
        if (self.training or self._first_eval_run) and self._sync_buf:
            with torch.no_grad():
                self._reduce_module_buffer()

            if self._first_eval_run:
                self._first_eval_run = False

        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
            for k, v in kwargs.items():
                kwargs[k] = self._convert_to_fp16(v)

        out = self.model(*args, **kwargs)

        if self._output_to_fp32:
            if isinstance(out, Tensor):
                out = self._convert_to_fp32(out)
            elif isinstance(out, (tuple, list)):
                out = [self._convert_to_fp32(val) for val in out]
            elif isinstance(out, dict):
                out = {key: self._convert_to_fp32(val) for key, val in out.items()}
        return out
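
To close, a hedged sketch of the casting behaviour of `NaiveAMPModel`. It is not part of this commit: it assumes a CUDA device, and that when the data-parallel group is not initialized the wrapper simply skips buffer synchronization, as in the constructor above.

# Illustrative only: fp32 inputs are cast to fp16, outputs are cast back to fp32.
import torch
import torch.nn as nn

from colossalai.legacy.amp.naive_amp import NaiveAMPModel

wrapped = NaiveAMPModel(nn.Linear(8, 8).cuda(), output_to_fp32=True)

x = torch.randn(4, 8, device='cuda')                            # fp32 input
y = wrapped(x)                                                  # cast to fp16 inside forward
assert next(wrapped.model.parameters()).dtype == torch.float16  # weights were halved
assert y.dtype == torch.float32                                 # output cast back to fp32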