Mirror of https://github.com/hpcaitech/ColossalAI.git
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
  * [bf16] fused adam kernel support bf16
  * [test] update fused adam kernel test
  * [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
  * [bf16] add mixed precision mixin
  * [bf16] low level zero optim support bf16
  * [test] update low level zero test
  * [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
  * [bf16] gemini support bf16
  * [test] update gemini bf16 test
  * [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
  * [zero] init context support bf16
  * [zero] legacy zero support bf16
  * [test] add zero bf16 test
  * [doc] add bf16 related docstring for legacy zero
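For orientation before the diff: the low-level ZeRO optimizer now picks its mixed-precision behavior from the working parameters' dtype, so bf16 training is enabled simply by running the model in torch.bfloat16 (no loss scaling is set up in that case). The booster/plugin wiring below is a hedged usage sketch only; the plugin's `precision` argument and the launch call are assumptions for illustration, not taken from this commit.

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# assumes a torchrun-launched distributed environment
colossalai.launch_from_torch(config={})

model = nn.Linear(16, 16)
optimizer = Adam(model.parameters(), lr=1e-3)

# precision='bf16' is assumed to route to BF16MixedPrecisionMixin (no loss scaling);
# 'fp16' keeps the DynamicGradScaler-backed path shown in the diff below.
plugin = LowLevelZeroPlugin(stage=2, precision='bf16')
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)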
@@ -6,7 +6,11 @@ import torch
import torch.distributed as dist
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import (
    BF16MixedPrecisionMixin,
    FP16MixedPrecisionMixin,
    MixedPrecisionMixin,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
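The newly imported mixin classes take over the scaling and overflow bookkeeping that this optimizer previously wired up by hand (see the removed DynamicGradScaler block further down). Only a handful of mixin methods are exercised in the rest of the diff; the sketch below is an inferred reading aid based on those call sites, not the actual colossalai.amp.naive_amp.mixed_precision_mixin source.

from abc import ABC, abstractmethod

from torch import Tensor


class MixedPrecisionMixinSketch(ABC):
    """Inferred interface: only the methods invoked in this diff are listed."""

    @abstractmethod
    def pre_backward(self, loss: Tensor) -> Tensor:
        """Return the loss to call .backward() on (scaled for fp16, unchanged for bf16)."""

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Hook called before gradients are cleared."""

    @abstractmethod
    def should_skip_step(self) -> bool:
        """True if this step must be skipped (e.g. an fp16 overflow was detected)."""

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Factor used to unscale gradients before clipping and the update (1.0 for bf16)."""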
@@ -27,6 +31,31 @@ from ._utils import (
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self,
                 num_working_param_groups: int,
                 grad_store: GradientStore,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.num_working_param_groups = num_working_param_groups
        self.grad_store = grad_store

    def check_local_overflow(self) -> bool:
        for group_id in range(self.num_working_param_groups):
            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    return True
        return False


class LowLevelZeroOptimizer(ColossalaiOptimizer):
    """Optimizer used for ZeRO-1 and ZeRO-2.
    """
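check_local_overflow() above only inspects the averaged gradients held by the local rank; has_inf_or_nan comes from ._utils and is not shown in this hunk. A minimal stand-in with the same intent, written here as an assumption rather than the library implementation:

import torch


def has_inf_or_nan_sketch(tensor: torch.Tensor) -> bool:
    """Return True if any element of `tensor` is +/-inf or NaN."""
    return bool(torch.isinf(tensor).any() or torch.isnan(tensor).any())


# Example: a fully finite gradient passes, a poisoned one does not.
assert not has_inf_or_nan_sketch(torch.ones(4))
assert has_inf_or_nan_sketch(torch.tensor([1.0, float('nan')]))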
@@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale,
                                             verbose=verbose)
        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

@@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        # initialize mixed precision mixin
        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
        if self._dtype is torch.float16:
            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
                                                                             self._grad_store,
                                                                             initial_scale=initial_scale,
                                                                             min_scale=min_scale,
                                                                             growth_factor=growth_factor,
                                                                             backoff_factor=backoff_factor,
                                                                             growth_interval=growth_interval,
                                                                             hysteresis=hysteresis,
                                                                             max_scale=max_scale)
        elif self._dtype is torch.bfloat16:
            self.mixed_precision_mixin = BF16MixedPrecisionMixin()

    @property
    def dtype(self):
        return self._dtype

    @property
    def loss_scale(self):
        return self.grad_scaler.scale

    @property
    def num_param_groups(self):
        return len(self._working_param_groups)
@@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    ################################

    def backward(self, loss, retain_graph=False, sync_grad=True):
        loss = self.loss_scale * loss
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)
        loss.backward(retain_graph=retain_graph)

        # finish gradient reduction
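In this hunk the unconditional `loss = self.loss_scale * loss` is the pre-change line; the new code defers to the mixin's pre_backward(). The fp16 mixin still multiplies the loss by the current scale, while bf16 can hand it back untouched, since bf16 shares fp32's exponent range and needs no scaling. A tiny illustration of that assumed split:

import torch

# Assumed behaviors inferred from this hunk; not the library implementation.
loss = torch.tensor(2.5, requires_grad=True)

fp16_scale = 2.0**16
fp16_scaled = loss * fp16_scale      # what the fp16 pre_backward would return (assumption)
bf16_scaled = loss                   # what the bf16 pre_backward would return (assumption)

print(fp16_scaled.item(), bf16_scaled.item())   # 163840.0 2.5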
@@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        :param set_to_none: Whether set the gradient to None. Default value is True.
        :type set_to_none: bool
        """
        if self.mixed_precision_mixin is not None:
            self.mixed_precision_mixin.pre_zero_grad()
        for _, param_group in self._working_param_groups.items():
            for param in param_group:
                if set_to_none:
@@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        # check for overflow
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        # update loss scale if overflow occurs
        if found_inf:
        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            self._grad_store.reset_all_average_gradients()
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')
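Above, the explicit overflow bookkeeping (`_check_overflow()` plus `grad_scaler.update()`) and the `if found_inf:` guard are the pre-change lines; they collapse into a single should_skip_step() query, and only when it returns True are the averaged gradients dropped and the step skipped. Presumably the bf16 mixin never asks for a skip, since it performs no loss scaling; a minimal sketch of that assumption:

class BF16SkipSketch:
    """Assumption: without loss scaling there is no scale-induced overflow to recover from."""

    def should_skip_step(self) -> bool:
        return False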
@@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    # Mixed Precision Utilities #
    #############################

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group_id in range(len(self._working_param_groups)):
            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

        # all-reduce over model parallel group
        if self._mp_torch_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

        if self._found_overflow.item() > 0:
            return True
        else:
            return False

    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        div_scale = 1.0
        if self.mixed_precision_mixin is not None:
            div_scale = self.mixed_precision_mixin.get_grad_div_scale()

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                combined_scale = clip * self.loss_scale
                div_scale = clip * div_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)
            grad.data.mul_(1. / div_scale)

    ############################
    # Gradient Synchronization #
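With the mixin in place, _unscale_and_clip_grads divides by a single div_scale: the fp16 mixin's get_grad_div_scale() presumably returns the current loss scale, bf16 needs no unscaling (div_scale stays 1.0), and the clipping factor is folded into the same divisor. The adjacent loss_scale/combined_scale lines above are the pre-change versions of the same statements. A small numeric walk-through under those assumptions:

# Assumed inputs: loss scale 2**16 (fp16 path), gradient norm measured on the
# still-scaled gradients, and clip_grad_norm = 1.0.
div_scale = 2.0**16            # from get_grad_div_scale(); would be 1.0 for bf16
total_norm = 131072.0          # norm of the scaled gradients, i.e. 2.0 * div_scale
clip_grad_norm = 1.0

clip = (total_norm / div_scale + 1e-6) / clip_grad_norm   # ~2.0: the true norm exceeds the limit
if clip > 1:
    div_scale = clip * div_scale                          # one division handles unscaling *and* clipping

unscaled_and_clipped = [g / div_scale for g in [65536.0, 131072.0]]
print(unscaled_and_clipped)    # ~[0.5, 1.0]: true gradients halved so the norm fits the clip threshold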