[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
  * [bf16] fused adam kernel support bf16
  * [test] update fused adam kernel test
  * [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
  * [bf16] add mixed precision mixin
  * [bf16] low level zero optim support bf16
  * [test] update low level zero test
  * [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
  * [bf16] gemini support bf16
  * [test] update gemini bf16 test
  * [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
  * [zero] init context support bf16
  * [zero] legacy zero support bf16
  * [test] add zero bf16 test
  * [doc] add bf16 related docstring for legacy zero
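Why bf16 removes the need for a gradient scaler: bfloat16 keeps float32's 8 exponent bits and gives up mantissa precision instead, so gradients rarely overflow and the dynamic loss scaling used for fp16 can be dropped, which is exactly what the bf16 code paths in this commit do. A quick check in plain PyTorch (illustration only, not part of the diff) shows the range difference:

import torch

# bf16 trades mantissa precision for fp32's exponent range.
print(torch.finfo(torch.float16).max)     # 65504.0
print(torch.finfo(torch.bfloat16).max)    # ~3.39e+38, same order of magnitude as fp32
print(torch.finfo(torch.float16).eps)     # 0.0009765625 (10 mantissa bits)
print(torch.finfo(torch.bfloat16).eps)    # 0.0078125 (7 mantissa bits)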
@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
        strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
            Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
        scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
    """

    def __init__(self,

@@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
                 pin_memory: bool = False,
                 force_outputs_fp32: bool = False,
                 strict_ddp_mode: bool = False,
                 scatter_after_inference: bool = True) -> None:
                 scatter_after_inference: bool = True,
                 mixed_precision: torch.dtype = torch.float16) -> None:
        assert mixed_precision in (torch.float16, torch.bfloat16)
        self.gemini_manager = gemini_manager
        self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
        self.force_outputs_fp32 = force_outputs_fp32

@@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
        self.param2name: Dict[nn.Parameter, str] = dict()
        self.name2param: Dict[str, nn.Parameter] = dict()
        self.scatter_after_inference = scatter_after_inference
        self.mixed_precision = mixed_precision

        self._logger = get_dist_logger()

@@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
        assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
        ), "You should run a completed iteration as your warmup iter"

        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
        args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
        self.module.zero_grad(set_to_none=True)
        if not grad_flag:
            outputs = self._inference_forward(*args, **kwargs)
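The forward path above now casts positional and keyword arguments to self.mixed_precision instead of a hard-coded torch.half. The _cast_float helper itself is not part of this diff; the sketch below is an illustrative stand-in (not the actual helper) showing what such a recursive cast amounts to:

import torch
from typing import Any

def cast_float(x: Any, dtype: torch.dtype) -> Any:
    """Recursively cast floating-point tensors in nested containers to `dtype`.
    Illustrative stand-in for the `_cast_float` helper used above."""
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype)
    if isinstance(x, (list, tuple)):
        return type(x)(cast_float(v, dtype) for v in x)
    if isinstance(x, dict):
        return {k: cast_float(v, dtype) for k, v in x.items()}
    return x  # ints, bools, and non-float tensors pass through untouched

batch = {"input_ids": torch.randint(0, 100, (2, 8)), "features": torch.randn(2, 3)}
batch = cast_float(batch, torch.bfloat16)  # only the float tensor is converted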
@@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):

            # move ignored parameters to CUDA
            if is_ddp_ignored(p):
                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                continue

            # create a fp32 parameter
            fp32_data = p.data.float()
            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
            # create a fp16 parameter
            p.data = p.data.half()
            p.data = p.data.to(self.mixed_precision)

            # register the fp16 parameter and fp32 parameter in the chunk manager
            dp_world_size = p.process_group.dp_world_size()

@@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
                buffer.materialize()
            buffer.data = buffer.cuda()
            if torch.is_floating_point(buffer):
                buffer.data = buffer.half()
                buffer.data = buffer.to(self.mixed_precision)

    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
        """Convert parameter to ColoParameter in-place.

@@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
                 hidden_dim: Optional[int] = None,
                 min_chunk_size_mb: float = 32,
                 memstats: Optional[MemStats] = None,
                 mixed_precision: torch.dtype = torch.float16,
                 verbose: bool = False) -> None:
        """
        A torch.Module wrapper using ZeRO-DP and Gemini.

@@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
                                             strict_ddp_flag=strict_ddp_mode,
                                             verbose=verbose)
        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
                         scatter_after_inference)
        super().__init__(module,
                         gemini_manager,
                         pin_memory,
                         force_outputs_fp32,
                         strict_ddp_mode,
                         scatter_after_inference,
                         mixed_precision=mixed_precision)
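With the change above, GeminiDDP forwards a mixed_precision dtype down to ZeroDDP. A minimal construction sketch follows; only module, hidden_dim, min_chunk_size_mb and mixed_precision appear in this diff, so the import path and the omission of other constructor arguments (for example placement settings) are assumptions that may not hold for every version:

import torch
import torch.nn as nn
from colossalai.zero import GeminiDDP   # import path assumed

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

# Only torch.float16 and torch.bfloat16 pass the assert added to ZeroDDP.__init__.
model = GeminiDDP(model,
                  hidden_dim=1024,
                  min_chunk_size_mb=32,
                  mixed_precision=torch.bfloat16)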
@@ -1,7 +1,6 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple

import torch

@@ -9,7 +8,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored

@@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}


class OptimState(Enum):
    SCALED = 0
    UNSCALED = 1


class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self,
                 module: ZeroDDP,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.module = module

    def check_local_overflow(self) -> bool:
        return self.module.overflow_counter > 0

    def pre_zero_grad(self) -> None:
        self.module.overflow_counter = 0


class ZeroOptimizer(ColossalaiOptimizer):

@@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
        self.module = module
        self.gemini_manager = module.gemini_manager
        self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
        self.optim_state = OptimState.UNSCALED
        self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
        self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
        self.chunk16_set: Set[Chunk] = set()

@@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
        self.__init__optimizer()

        # Grad scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
        if module.mixed_precision is torch.float16:
            self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
                                                                     initial_scale=initial_scale,
                                                                     min_scale=min_scale,
                                                                     growth_factor=growth_factor,
                                                                     backoff_factor=backoff_factor,
                                                                     growth_interval=growth_interval,
                                                                     hysteresis=hysteresis,
                                                                     max_scale=max_scale)
        elif module.mixed_precision is torch.bfloat16:
            self.mix_precision_mixin = BF16MixedPrecisionMixin()
        else:
            raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")

        self._logger = get_dist_logger()

        self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
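ZeroOptimizer now delegates all fp16/bf16-specific behaviour to a mixin chosen from module.mixed_precision. The sketch below reconstructs the mixin contract from how it is called throughout this file (pre_backward, pre_backward_by_grad, should_skip_step, pre_zero_grad, get_grad_div_scale, plus the check_local_overflow hook overridden above); the real base classes in colossalai.amp.naive_amp.mixed_precision_mixin may differ in detail:

import torch
from abc import ABC, abstractmethod

class MixedPrecisionMixinSketch(ABC):
    """Rough shape of the mixin contract as it is used in this file; illustrative only."""

    @abstractmethod
    def pre_backward(self, loss: torch.Tensor) -> torch.Tensor:
        """Scale the loss (fp16) or return it unchanged (bf16)."""

    @abstractmethod
    def pre_backward_by_grad(self, grad: torch.Tensor) -> torch.Tensor:
        """Hook for pipeline-style backward that starts from a gradient instead of a loss."""

    @abstractmethod
    def should_skip_step(self) -> bool:
        """True when an overflow was detected and the optimizer step must be skipped."""

    @abstractmethod
    def pre_zero_grad(self) -> None:
        """Reset overflow bookkeeping before gradients are cleared."""

    @abstractmethod
    def get_grad_div_scale(self) -> float:
        """Divisor applied to gradients before the update (the loss scale for fp16, 1.0 for bf16)."""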
@@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
        for chunk16 in self.chunk16_set:
            chunk16.optim_update()

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(self.module.overflow_counter)

        # all-reduce across global group
        dist.all_reduce(self._found_overflow)

        return self._found_overflow.item() > 0

    def _clear_global_norm(self) -> None:
        for c16 in self.chunk16_set:
            c16.l2_norm = None

@@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer):
        return global_norm

    def _get_combined_scale(self):
        loss_scale = 1
        div_scale = self.mix_precision_mixin.get_grad_div_scale()

        if self.optim_state == OptimState.SCALED:
            loss_scale = self.loss_scale
            self.optim_state = OptimState.UNSCALED

        combined_scale = loss_scale
        if self.clipping_flag:
            total_norm = self._calc_global_norm()
            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
            if clip > 1:
                combined_scale = clip * loss_scale
                div_scale = clip * div_scale

        if combined_scale == 1:
            return -1
        else:
            return combined_scale

    @property
    def loss_scale(self):
        return self.grad_scaler.scale.item()
        return -1 if div_scale == 1.0 else div_scale

    def zero_grad(self, *args, **kwargs):
        self.module.overflow_counter = 0
        self.mix_precision_mixin.pre_zero_grad()
        return self.optim.zero_grad(set_to_none=True)

    def step(self, *args, **kwargs):
        self._maybe_move_fp32_params()
        self._set_grad_ptr()

        found_inf = self._check_overflow()
        if found_inf:
            self.optim_state = OptimState.UNSCALED # no need to unscale grad
            self.grad_scaler.update(found_inf) # update gradient scaler
        if self.mix_precision_mixin.should_skip_step():
            if self.verbose:
                self._logger.info(f'Found overflow. Skip step')
            self._clear_global_norm() # clear recorded norm
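The rewritten _get_combined_scale folds gradient clipping into the divisor returned by get_grad_div_scale. A worked example with made-up numbers:

div_scale = 1024.0     # get_grad_div_scale(): the loss scale under fp16, 1.0 under bf16
total_norm = 20480.0   # _calc_global_norm(): norm computed on the still-scaled gradients
max_norm = 10.0        # clipping threshold

clip = ((total_norm / div_scale) + 1e-6) / max_norm   # ~2.0: the unscaled norm is twice the limit
if clip > 1:
    div_scale = clip * div_scale                      # ~2048.0: clip by dividing gradients further

combined_scale = -1 if div_scale == 1.0 else div_scale   # -1 when there is nothing to divide by
# combined_scale (~2048.0 here) is then passed to self.optim.step(div_scale=combined_scale, ...)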
@@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
        # get combined scale. combined scale = loss scale * clipping norm
        # so that gradient = gradient / combined scale
        combined_scale = self._get_combined_scale()
        self.grad_scaler.update(found_inf)

        ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
        self._register_states()

@@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
        raise NotImplementedError

    def backward(self, loss: torch.Tensor):
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        loss = self.mix_precision_mixin.pre_backward(loss)
        self.module.backward(loss)

    def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):

@@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
        # It receives the scaled grad from the previous rank
        # No need to scale the grad again
        # Need to unscale when optimizing
        self.optim_state = OptimState.SCALED
        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
        self.module.backward_by_grad(tensor, grad)

    def _maybe_move_fp32_params(self):
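Taken together, these changes leave the public training loop unchanged; only the internals branch on the mixin. A hedged sketch of the intended call pattern (model and optimizer construction, the dataloader and the criterion are placeholders, not APIs from this diff):

for batch, labels in dataloader:          # placeholder iterable
    optimizer.zero_grad()                 # calls mix_precision_mixin.pre_zero_grad()
    loss = criterion(model(batch), labels)
    optimizer.backward(loss)              # fp16: loss is scaled; bf16: returned unchanged
    optimizer.step()                      # skipped automatically when the fp16 mixin reports an overflow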
@@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2

@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        seed (int, optional): Random seed for weight initialization
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
        default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
    """

@@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 seed: int = 2**10 - 1,
                 shard_param: bool = False,
                 default_dtype: Optional[torch.dtype] = None,
                 bf16: bool = False,
                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):

        super().__init__(default_dtype=default_dtype)

@@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        self.param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.seed = seed
        self.bf16 = bf16
        self.dp_process_group = gpc.get_group(ParallelMode.DATA)

        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

@@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        NOTE() The module may be passed to this function multiple times.
        """
        self.top_module = module
        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16

        def half_fn(t: torch.Tensor):
            return t.half() if t.is_floating_point() else t
            return t.to(half_dtype) if t.is_floating_point() else t

        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice

@@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        # We must cast buffers
        # If we use BN, buffers may be on CPU and Float
        # We must cast them
        cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
        for buffer in module.buffers(recurse=False):
            buffer.data = buffer.data.to(device=torch.cuda.current_device())
            buffer.data = cast_tensor_to_fp16(buffer.data)
            buffer.data = cast_fn(buffer.data)


class ZeroContextMgr(metaclass=SingletonMeta):
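The new bf16 flag of ZeroInitContext initializes parameters in torch.bfloat16 and routes buffer casting through cast_tensor_to_bf16. A minimal sketch follows, assuming a launched ColossalAI distributed context; the import paths, the shard_strategy argument and the TensorShardStrategy choice are assumptions, only bf16, shard_param, seed and default_dtype appear in this diff:

import torch
import torch.nn as nn
from colossalai.zero.legacy.init_ctx import ZeroInitContext          # path assumed
from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # one BaseShardStrategy implementation (assumed)

shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=shard_strategy,
                     shard_param=True,
                     bf16=True):
    module = nn.Linear(1024, 1024)   # parameters come out as torch.bfloat16 shards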
@@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te
    if isinstance(tensor, StatefulTensor):
        tensor = tensor.payload

    if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
    if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
        return tensor.float()
    return tensor


def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    if isinstance(tensor, StatefulTensor):
        tensor = tensor.payload
    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
        return tensor.bfloat16()
    return tensor


def apply_to_tensors(x: Any, fn: Callable):
    if torch.is_tensor(x):
        return fn(x)
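A small usage example of the helpers above (the import path is taken from the import line earlier in this commit):

import torch
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp32

x = torch.randn(4, 4)                  # fp32
x_bf16 = cast_tensor_to_bf16(x)        # fp32 -> bf16 (non-fp32 inputs are returned unchanged)
x_back = cast_tensor_to_fp32(x_bf16)   # bf16 (or fp16) -> fp32, now that bf16 is handled too
assert x_bf16.dtype is torch.bfloat16 and x_back.dtype is torch.float32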
@@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc

from ._utils import (
    cast_float_arguments,
    cast_tensor_to_bf16,
    cast_tensor_to_fp16,
    cast_tensor_to_fp32,
    chunk_and_pad,

@@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
            We find that PyTorch's optimizers don't support mixed precision,
            so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
        bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
    """

    def __init__(self,

@@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
                 tensor_placement_policy: str = 'cuda',
                 gradient_predivide_factor: Optional[float] = 1.0,
                 reuse_fp16_shard: bool = False,
                 bf16: bool = False,
                 *args,
                 **kwargs):
        assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
        super().__init__()
        self.logger = get_dist_logger()
        self.bf16 = bf16

        # We force users to use ZeroInitContext
        for submodule in module.modules():

@@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        self._pre_forward_operations(*args)
        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
        cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
        args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
        outputs = self.module(*args, **kwargs)
        self._post_forward_operations()
        return outputs
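Continuing the ZeroInitContext sketch above, the module is then wrapped in ShardedModelV2 with a matching bf16 flag; the positional argument order and the import path are assumptions:

from colossalai.zero.legacy.sharded_model import ShardedModelV2   # import path assumed

# `module` and `shard_strategy` are the objects from the ZeroInitContext sketch above.
# bf16=True switches the forward-argument cast from cast_tensor_to_fp16 to cast_tensor_to_bf16,
# and ShardedOptimizerV2 later reads the flag back as `sharded_model.bf16`.
model = ShardedModelV2(module, shard_strategy, bf16=True)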
@@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        super().__init__(optimizer)
        self.shard_strategy = sharded_model.shard_strategy
        self.model: ShardedModelV2 = sharded_model
        self.bf16 = sharded_model.bf16

        self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'

@@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
        self._logger = get_dist_logger("ShardedOptimizerV2")
        self._verbose = verbose
        self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward

        # Store fp32 param shards
        self._register_master_weight()

@@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self._zero_grad()

    def backward(self, loss: Tensor) -> None:
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        if not self.bf16:
            loss = self.loss_scale * loss
            self.optim_state = OptimState.SCALED
        self._grad_prepared = False
        self.model.backward(loss)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:

@@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        # It receives the scaled grad from the previous rank
        # No need to scale the grad again
        # Need to unscale when optimizing
        self.optim_state = OptimState.SCALED
        if not self.bf16:
            self.optim_state = OptimState.SCALED
        self._grad_prepared = False
        self.model.backward_by_grad(tensor, grad)

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
        self._prepare_grads()
        if not self.bf16 and self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        return super().clip_grad_norm(model, max_norm)

    def step(self, *args, **kwargs):

        self._prepare_grads()
        # unscale grads if scaled
        if self.optim_state == OptimState.SCALED:
            self._prepare_grads()
        if not self.bf16 and self.optim_state == OptimState.SCALED:
            self._unscale_grads()

        self._maybe_move_fp32_shards()
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)
        if not self.bf16:
            found_inf = self._check_overflow()
            self.grad_scaler.update(found_inf)

        if found_inf:
            self._logger.warning('found inf during ShardedOptimV2 step')
            self._zero_grad(recover_data=True)
            return
            if found_inf:
                self._logger.warning('found inf during ShardedOptimV2 step')
                self._zero_grad(recover_data=True)
                return

        self._point_param_fp16_to_master_param()

@@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                        state[k] = v.cuda()

    def _prepare_grads(self):
        if self._grad_prepared:
            return
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.colo_attr.saved_grad.is_null():

@@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                p.grad = p.colo_attr.grad_payload
                # Set p.data to empty tensor, in case of memory leaking
                p.colo_attr.set_data_none()
        self._grad_prepared = True

    def _point_param_fp16_to_master_param(self):
        # assign master param pointers to p.data.

@@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                    torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))

                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
                half_dtype = torch.bfloat16 if self.bf16 else torch.float16
                p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
                p.colo_attr.set_data_none()

                if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
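With a bf16 ShardedModelV2, ShardedOptimizerV2 takes the new unscaled path: backward() leaves the loss untouched and step() skips the overflow check and scaler update. A sketch continuing the example above; the import path and argument order are assumptions, and HybridAdam is just one of the supported optimizers:

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2   # import path assumed

optimizer = ShardedOptimizerV2(model, HybridAdam(model.parameters(), lr=1e-3))

optimizer.zero_grad()
loss = model(inputs).float().mean()   # `inputs` is a placeholder batch
optimizer.backward(loss)              # no loss scaling, because self.bf16 is True
optimizer.step()                      # no grad-scaler update or overflow short-circuit on the bf16 path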
@@ -6,7 +6,11 @@ import torch
import torch.distributed as dist
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import (
    BF16MixedPrecisionMixin,
    FP16MixedPrecisionMixin,
    MixedPrecisionMixin,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger

@@ -27,6 +31,31 @@ from ._utils import (
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket


class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):

    def __init__(self,
                 num_working_param_groups: int,
                 grad_store: GradientStore,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.num_working_param_groups = num_working_param_groups
        self.grad_store = grad_store

    def check_local_overflow(self) -> bool:
        for group_id in range(self.num_working_param_groups):
            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    return True
        return False


class LowLevelZeroOptimizer(ColossalaiOptimizer):
    """Optimizer used for ZeRO-1 and ZeRO-2.
    """

@@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale,
                                             verbose=verbose)
        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

@@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        # initialize mixed precision mixin
        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
        if self._dtype is torch.float16:
            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
                                                                             self._grad_store,
                                                                             initial_scale=initial_scale,
                                                                             min_scale=min_scale,
                                                                             growth_factor=growth_factor,
                                                                             backoff_factor=backoff_factor,
                                                                             growth_interval=growth_interval,
                                                                             hysteresis=hysteresis,
                                                                             max_scale=max_scale)
        elif self._dtype is torch.bfloat16:
            self.mixed_precision_mixin = BF16MixedPrecisionMixin()

    @property
    def dtype(self):
        return self._dtype

    @property
    def loss_scale(self):
        return self.grad_scaler.scale

    @property
    def num_param_groups(self):
        return len(self._working_param_groups)
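For the low-level ZeRO optimizer, the mixin is selected from self._dtype, which reflects the dtype of the working parameters (how _dtype is derived is not shown in this hunk). A hedged usage sketch, assuming a launched distributed GPU environment; the import path and the minimal constructor call are assumptions:

import torch
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import LowLevelZeroOptimizer   # import path assumed

model = nn.Linear(1024, 1024).cuda().to(torch.bfloat16)   # bf16 working parameters
optim = HybridAdam(model.parameters(), lr=1e-3)

# With bf16 params, self._dtype is torch.bfloat16, so BF16MixedPrecisionMixin is chosen
# and the fp16 scaler settings (initial_scale, growth_factor, ...) are simply unused.
optimizer = LowLevelZeroOptimizer(optim)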
@@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    ################################

    def backward(self, loss, retain_graph=False, sync_grad=True):
        loss = self.loss_scale * loss
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)
        loss.backward(retain_graph=retain_graph)

        # finish gradient reduction

@@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        :param set_to_none: Whether set the gradient to None. Default value is True.
        :type set_to_none: bool
        """
        if self.mixed_precision_mixin is not None:
            self.mixed_precision_mixin.pre_zero_grad()
        for _, param_group in self._working_param_groups.items():
            for param in param_group:
                if set_to_none:

@@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        # check for overflow
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        # update loss scale if overflow occurs
        if found_inf:
        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
            self._grad_store.reset_all_average_gradients()
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')

@@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    # Mixed Precision Utilities #
    #############################

    def _check_overflow(self):
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group_id in range(len(self._working_param_groups)):
            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

        # all-reduce over model parallel group
        if self._mp_torch_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

        if self._found_overflow.item() > 0:
            return True
        else:
            return False

    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        div_scale = 1.0
        if self.mixed_precision_mixin is not None:
            div_scale = self.mixed_precision_mixin.get_grad_div_scale()

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                combined_scale = clip * self.loss_scale
                div_scale = clip * div_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)
            grad.data.mul_(1. / div_scale)

    ############################
    # Gradient Synchronization #