mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 11:44:03 +00:00)
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
* [bf16] fused adam kernel support bf16
* [test] update fused adam kernel test
* [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
* [bf16] add mixed precision mixin
* [bf16] low level zero optim support bf16
* [test] update low level zero test
* [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
* [bf16] gemini support bf16
* [test] update gemini bf16 test
* [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
* [zero] init context support bf16
* [zero] legacy zero support bf16
* [test] add zero bf16 test
* [doc] add bf16 related docstring for legacy zero
@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
         strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
             Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
         scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
+        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
     """

     def __init__(self,
@@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
-                 scatter_after_inference: bool = True) -> None:
+                 scatter_after_inference: bool = True,
+                 mixed_precision: torch.dtype = torch.float16) -> None:
+        assert mixed_precision in (torch.float16, torch.bfloat16)
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
+        self.mixed_precision = mixed_precision

         self._logger = get_dist_logger()

@@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
         assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
         ), "You should run a completed iteration as your warmup iter"

-        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
+        args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
         self.module.zero_grad(set_to_none=True)
         if not grad_flag:
             outputs = self._inference_forward(*args, **kwargs)
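With `mixed_precision` stored on the module, the forward pass casts inputs to the configured dtype instead of hard-coding `torch.half`. The diff does not show `_cast_float` itself; a minimal sketch of what such a helper plausibly does (recursively casting only floating-point tensors inside nested containers) is:

```python
from typing import Any

import torch


def cast_float(obj: Any, dtype: torch.dtype) -> Any:
    """Recursively cast floating-point tensors in nested containers to `dtype`.

    Illustrative sketch only; the real `_cast_float` used by ZeroDDP may differ.
    """
    if isinstance(obj, torch.Tensor) and obj.is_floating_point():
        return obj.to(dtype)
    if isinstance(obj, (list, tuple)):
        return type(obj)(cast_float(x, dtype) for x in obj)
    if isinstance(obj, dict):
        return {k: cast_float(v, dtype) for k, v in obj.items()}
    return obj  # non-float tensors (e.g. token ids) and other objects pass through
```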
@@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):

             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                 continue

             # create a fp32 parameter
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             # create a fp16 parameter
-            p.data = p.data.half()
+            p.data = p.data.to(self.mixed_precision)

             # register the fp16 parameter and fp32 parameter in the chunk manager
             dp_world_size = p.process_group.dp_world_size()
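This is the standard mixed-precision layout: an fp32 master copy of every parameter is kept for the optimizer update, while the working copy used in forward/backward lives in fp16 or bf16. Stripped of Gemini's chunk machinery, the idea reduces to a sketch like the following (illustrative only):

```python
import torch


def make_mixed_precision_copies(model: torch.nn.Module, dtype: torch.dtype = torch.bfloat16):
    """Keep fp32 master weights and cast the working parameters to `dtype`.

    Sketch only: ZeroDDP additionally registers both copies in its chunk manager
    and shards them across data-parallel ranks.
    """
    assert dtype in (torch.float16, torch.bfloat16)
    master_params = []
    for p in model.parameters():
        master_params.append(p.data.float().clone())  # fp32 copy used by the optimizer
        p.data = p.data.to(dtype)                      # low-precision working copy
    return master_params
```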
@@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
                 buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
-                buffer.data = buffer.half()
+                buffer.data = buffer.to(self.mixed_precision)

     def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
         """Convert parameter to ColoParameter in-place.
@@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
                  memstats: Optional[MemStats] = None,
+                 mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
         """
         A torch.Module wrapper using ZeRO-DP and Gemini.
@@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
                                            strict_ddp_flag=strict_ddp_mode,
                                            verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
-                         scatter_after_inference)
+        super().__init__(module,
+                         gemini_manager,
+                         pin_memory,
+                         force_outputs_fp32,
+                         strict_ddp_mode,
+                         scatter_after_inference,
+                         mixed_precision=mixed_precision)
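`GeminiDDP` only forwards `mixed_precision` to `ZeroDDP`, so the training dtype is fixed once at construction time. How a caller chooses it is outside this diff; a common pattern (the hardware check below is an assumption, not part of this commit) is:

```python
import torch

# bf16 generally requires Ampere-or-newer GPUs; otherwise fall back to fp16,
# which relies on the dynamic loss scaling path shown in the optimizer below.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
mixed_precision = torch.bfloat16 if use_bf16 else torch.float16
```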
@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import math
 import warnings
-from enum import Enum
 from typing import Any, Dict, Set, Tuple

 import torch
@@ -9,7 +8,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import Optimizer

-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
@@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
 _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}


-class OptimState(Enum):
-    SCALED = 0
-    UNSCALED = 1
+class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+
+    def __init__(self,
+                 module: ZeroDDP,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32) -> None:
+        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
+                         max_scale)
+        self.module = module
+
+    def check_local_overflow(self) -> bool:
+        return self.module.overflow_counter > 0
+
+    def pre_zero_grad(self) -> None:
+        self.module.overflow_counter = 0


 class ZeroOptimizer(ColossalaiOptimizer):
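`GeminiFP16MixedPrecisionMixin` only supplies the Gemini-specific pieces, namely where the overflow counter lives and how it is reset; dynamic loss scaling itself comes from the `FP16MixedPrecisionMixin` base. For bf16 no loss scaling is needed, which is what `BF16MixedPrecisionMixin` provides in the optimizer below. A simplified stand-in showing what such a bf16 mixin conceptually amounts to (not ColossalAI's actual implementation):

```python
import torch


class NoScalingBF16Mixin:
    """Conceptual stand-in for a bf16 mixed-precision mixin: bf16 shares fp32's
    exponent range, so gradients rarely overflow and no loss scaling is applied."""

    def pre_backward(self, loss: torch.Tensor) -> torch.Tensor:
        return loss  # no loss scaling

    def pre_backward_by_grad(self, grad: torch.Tensor) -> torch.Tensor:
        return grad  # incoming gradients are already unscaled

    def should_skip_step(self) -> bool:
        return False  # no dynamic scaler to back off, so never skip

    def pre_zero_grad(self) -> None:
        pass

    def get_grad_div_scale(self) -> float:
        return 1.0  # gradients are already at the right scale
```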
@@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
-        self.optim_state = OptimState.UNSCALED
         self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
         self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
         self.chunk16_set: Set[Chunk] = set()
@@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer):

         self.__init__optimizer()

-        # Grad scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        if module.mixed_precision is torch.float16:
+            self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
+                                                                     initial_scale=initial_scale,
+                                                                     min_scale=min_scale,
+                                                                     growth_factor=growth_factor,
+                                                                     backoff_factor=backoff_factor,
+                                                                     growth_interval=growth_interval,
+                                                                     hysteresis=hysteresis,
+                                                                     max_scale=max_scale)
+        elif module.mixed_precision is torch.bfloat16:
+            self.mix_precision_mixin = BF16MixedPrecisionMixin()
+        else:
+            raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
+
         self._logger = get_dist_logger()

         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
@@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()

-    def _check_overflow(self):
-        # clear previous overflow record
-        self._found_overflow.fill_(self.module.overflow_counter)
-
-        # all-reduce across global group
-        dist.all_reduce(self._found_overflow)
-
-        return self._found_overflow.item() > 0
-
     def _clear_global_norm(self) -> None:
         for c16 in self.chunk16_set:
             c16.l2_norm = None
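Dropping `_check_overflow` does not drop the overflow check: the per-rank signal is now exposed through `check_local_overflow`, and aggregating it across ranks is left to the mixin machinery before `should_skip_step` decides whether to skip. A hedged sketch of how the removed all-reduce style check can be rebuilt from the pieces shown here (assuming an initialized default process group; the real logic lives inside ColossalAI's mixed precision mixins):

```python
import torch
import torch.distributed as dist


def any_rank_overflowed(local_overflow: bool) -> bool:
    """All-reduce a per-rank overflow flag, mirroring the removed _check_overflow."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    flag = torch.tensor([1.0 if local_overflow else 0.0], device=device)
    if dist.is_initialized():
        dist.all_reduce(flag)  # a sum > 0 means some rank saw inf/nan gradients
    return flag.item() > 0
```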
@@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer):
         return global_norm

     def _get_combined_scale(self):
-        loss_scale = 1
+        div_scale = self.mix_precision_mixin.get_grad_div_scale()

-        if self.optim_state == OptimState.SCALED:
-            loss_scale = self.loss_scale
-            self.optim_state = OptimState.UNSCALED
-
-        combined_scale = loss_scale
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * loss_scale
+                div_scale = clip * div_scale

-        if combined_scale == 1:
-            return -1
-        else:
-            return combined_scale
-
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale.item()
+        return -1 if div_scale == 1.0 else div_scale

     def zero_grad(self, *args, **kwargs):
-        self.module.overflow_counter = 0
+        self.mix_precision_mixin.pre_zero_grad()
         return self.optim.zero_grad(set_to_none=True)

     def step(self, *args, **kwargs):
         self._maybe_move_fp32_params()
         self._set_grad_ptr()

-        found_inf = self._check_overflow()
-        if found_inf:
-            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
-            self.grad_scaler.update(found_inf)    # update gradient scaler
+        if self.mix_precision_mixin.should_skip_step():
             if self.verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self._clear_global_norm()    # clear recorded norm
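`_get_combined_scale` now asks the mixin for a gradient division scale (the dynamic loss scale under fp16, 1.0 under bf16) and folds gradient clipping into the same divisor, returning -1 when no division is required. The arithmetic restated as a standalone sketch:

```python
def combined_div_scale(div_scale: float, total_norm: float, max_norm: float, clipping: bool) -> float:
    """Combine loss scale and gradient clipping into one divisor, as in _get_combined_scale.

    Returns -1 when no rescaling is needed, matching the sentinel the diff passes
    to self.optim.step(div_scale=...).
    """
    if clipping:
        clip = ((total_norm / div_scale) + 1e-6) / max_norm
        if clip > 1:
            div_scale = clip * div_scale
    return -1 if div_scale == 1.0 else div_scale


# e.g. bf16 (div_scale=1.0) with grad norm 10 and max_norm 1.0 -> divide grads by ~10
print(combined_div_scale(1.0, total_norm=10.0, max_norm=1.0, clipping=True))
```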
@@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
         # get combined scale. combined scale = loss scale * clipping norm
         # so that gradient = gradient / combined scale
         combined_scale = self._get_combined_scale()
-        self.grad_scaler.update(found_inf)

         ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
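The combined scale is handed to the wrapped optimizer as `div_scale`, which presumably divides the gradients inside the fused update. A toy, non-fused illustration of consuming such a divisor (the -1 sentinel meaning "no division" follows the convention above):

```python
import torch


def toy_sgd_step(params, lr: float = 1e-3, div_scale: float = -1.0) -> None:
    """Toy illustration of an optimizer consuming a combined division scale.

    Not the FusedAdam/CPUAdam kernel; just shows how a single divisor can undo
    loss scaling and apply clipping in one pass over the gradients.
    """
    scale = 1.0 if div_scale <= 0 else div_scale
    with torch.no_grad():
        for p in params:
            if p.grad is not None:
                p -= lr * (p.grad / scale)
```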
@@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         raise NotImplementedError

     def backward(self, loss: torch.Tensor):
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
+        loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)

     def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
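The reason fp16 needs `pre_backward` loss scaling while the bf16 path can return the loss untouched is dynamic range: bf16 keeps fp32's 8-bit exponent and gives up mantissa bits, whereas fp16 saturates near 6.5e4 and underflows early. A quick check:

```python
import torch

# fp16: narrow range -> small gradients underflow, large ones overflow, hence loss scaling
print(torch.finfo(torch.float16).max, torch.finfo(torch.float16).tiny)    # 65504.0, ~6.1e-05
# bf16: fp32-like exponent range -> loss scaling is usually unnecessary
print(torch.finfo(torch.bfloat16).max, torch.finfo(torch.bfloat16).tiny)  # ~3.39e+38, ~1.18e-38
```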
@@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        self.optim_state = OptimState.SCALED
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
         self.module.backward_by_grad(tensor, grad)

     def _maybe_move_fp32_params(self):