[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [test] update low level zero test

* [test] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
Author: Hongxin Liu
Authored: 2023-06-05 15:58:31 +08:00
Committed by: GitHub
Parent: 07cb21142f
Commit: ae02d4e4f7

27 changed files with 738 additions and 525 deletions
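
For context on why most of this diff deletes scaling machinery rather than adding it: bf16 keeps float32's 8-bit exponent, so the dynamic loss scaling that fp16 needs to avoid gradient overflow is unnecessary, and the hard-wired DynamicGradScaler can be replaced by a MixedPrecisionMixin that only does real work in fp16 mode. A quick check of the representable ranges (illustration only, not part of the diff):

import torch

# fp16's 5-bit exponent tops out near 6.5e4, so unscaled gradients easily
# overflow to inf; bf16 shares fp32's 8-bit exponent, trading mantissa
# bits for range, so no loss scaling is required.
print(torch.finfo(torch.float16).max)     # 65504.0
print(torch.finfo(torch.bfloat16).max)    # ~3.39e+38
print(torch.finfo(torch.float32).max)     # ~3.40e+38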


@@ -6,7 +6,11 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.amp.naive_amp.mixed_precision_mixin import (
+    BF16MixedPrecisionMixin,
+    FP16MixedPrecisionMixin,
+    MixedPrecisionMixin,
+)
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
@@ -27,6 +31,31 @@ from ._utils import (
 from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
 
 
+class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+
+    def __init__(self,
+                 num_working_param_groups: int,
+                 grad_store: GradientStore,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32) -> None:
+        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
+                         max_scale)
+        self.num_working_param_groups = num_working_param_groups
+        self.grad_store = grad_store
+
+    def check_local_overflow(self) -> bool:
+        for group_id in range(self.num_working_param_groups):
+            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
+                if avg_grad is not None and has_inf_or_nan(avg_grad):
+                    return True
+        return False
+
+
 class LowLevelZeroOptimizer(ColossalaiOptimizer):
     """Optimizer used for ZeRO-1 and ZeRO-2.
     """
@@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
 
-        # gradient scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale,
-                                             verbose=verbose)
-        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
-
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         if self._overlap_communication or self._partition_grads:
             self._attach_reduction_hook()
 
+        # initialize mixed precision mixin
+        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
+        if self._dtype is torch.float16:
+            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
+                                                                             self._grad_store,
+                                                                             initial_scale=initial_scale,
+                                                                             min_scale=min_scale,
+                                                                             growth_factor=growth_factor,
+                                                                             backoff_factor=backoff_factor,
+                                                                             growth_interval=growth_interval,
+                                                                             hysteresis=hysteresis,
+                                                                             max_scale=max_scale)
+        elif self._dtype is torch.bfloat16:
+            self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+
     @property
     def dtype(self):
         return self._dtype
 
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale
-
     @property
     def num_param_groups(self):
         return len(self._working_param_groups)
@@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     ################################
 
     def backward(self, loss, retain_graph=False, sync_grad=True):
-        loss = self.loss_scale * loss
+        if self.mixed_precision_mixin is not None:
+            loss = self.mixed_precision_mixin.pre_backward(loss)
         loss.backward(retain_graph=retain_graph)
 
         # finish gradient reduction
@@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         :param set_to_none: Whether set the gradient to None. Default value is True.
         :type set_to_none: bool
         """
+        if self.mixed_precision_mixin is not None:
+            self.mixed_precision_mixin.pre_zero_grad()
         for _, param_group in self._working_param_groups.items():
             for param in param_group:
                 if set_to_none:
@@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'
 
-        # check for overflow
-        found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
-
-        # update loss scale if overflow occurs
-        if found_inf:
+        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            self._grad_store.reset_all_average_gradients()
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')
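
The should_skip_step() call above folds together what step() used to do by hand: run a local inf/nan scan over the averaged gradients, agree on the result across ranks, and update the loss scale. A rough sketch of how the fp16 path could look, assuming an initialized process group and the check_local_overflow() hook added earlier (the grad-scaler update is elided):

import torch
import torch.distributed as dist

def should_skip_step(mixin, dp_group) -> bool:
    # local scan of this rank's averaged gradients
    found = torch.zeros(1, device=torch.cuda.current_device())
    if mixin.check_local_overflow():
        found.fill_(1.0)
    # if any rank saw inf/nan, every rank skips the step
    dist.all_reduce(found, op=dist.ReduceOp.MAX, group=dp_group)
    return found.item() > 0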
@@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # Mixed Precision Utilities #
     #############################
 
-    def _check_overflow(self):
-        # clear previous overflow record
-        self._found_overflow.fill_(0.0)
-
-        # check for overflow
-        for group_id in range(len(self._working_param_groups)):
-            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
-                if avg_grad is not None and has_inf_or_nan(avg_grad):
-                    self._found_overflow.fill_(1.0)
-                    break
-
-        # all-reduce across dp group
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
-
-        # all-reduce over model parallel group
-        if self._mp_torch_group:
-            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
-
-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
-
     def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
-        # compute combined scale factor for this group
-        combined_scale = self.loss_scale
+        div_scale = 1.0
+        if self.mixed_precision_mixin is not None:
+            div_scale = self.mixed_precision_mixin.get_grad_div_scale()
 
         if self._clip_grad_norm > 0.:
             # norm is in fact norm*scale
-            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
+            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
             if clip > 1:
-                combined_scale = clip * self.loss_scale
+                div_scale = clip * div_scale
 
         for grad in grad_groups_flat:
-            grad.data.mul_(1. / combined_scale)
+            grad.data.mul_(1. / div_scale)
 
     ############################
     # Gradient Synchronization #
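
The rewritten _unscale_and_clip_grads is easiest to read with concrete (made-up) numbers: div_scale starts as the loss scale (1.0 under bf16), and when the still-scaled gradient norm exceeds the clip threshold, the clipping factor is folded into the same divide, so each gradient is touched only once:

loss_scale = 1024.0      # what get_grad_div_scale() would return under fp16; bf16 gives 1.0
total_norm = 4096.0      # norm of the still-scaled flat gradients
clip_grad_norm = 1.0     # configured clipping threshold

div_scale = loss_scale
clip = (total_norm / div_scale + 1e-6) / clip_grad_norm    # unscaled norm vs. threshold = ~4.0
if clip > 1:
    div_scale = clip * div_scale                           # fold clipping into the divide

# grad.mul_(1 / div_scale) then both unscales and clips:
# 4096 / 1024 / 4 = 1.0, i.e. the final gradient norm lands at clip_grad_norm.
print(div_scale)    # ~4096.0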