Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 03:31:56 +00:00)
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
* [bf16] fused adam kernel support bf16
* [test] update fused adam kernel test
* [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
* [bf16] add mixed precision mixin
* [bf16] low level zero optim support bf16
* [test] update low level zero test
* [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
* [bf16] gemini support bf16
* [test] update gemini bf16 test
* [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
* [zero] init context support bf16
* [zero] legacy zero support bf16
* [test] add zero bf16 test
* [doc] add bf16 related docstring for legacy zero
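For orientation, a hedged usage sketch of what this commit enables at the optimizer level: training a bfloat16 model with ColossalAI's HybridAdam, which per the changelog above now accepts bf16 parameters and gradients. The import path and the availability of the prebuilt CPU/CUDA Adam kernels are assumptions, not shown in this diff:

    import torch
    from colossalai.nn.optimizer import HybridAdam  # assumed import path

    model = torch.nn.Linear(64, 64).cuda().to(torch.bfloat16)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    x = torch.randn(8, 64, device='cuda', dtype=torch.bfloat16)
    loss = model(x).float().sum()   # no GradScaler needed: bf16 keeps fp32's exponent range
    loss.backward()
    optimizer.step()
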
@@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.legacy.shard_utils import BaseShardStrategy
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.legacy.sharded_param import ShardedParamV2

@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         seed (int, optional): Random seed for weight initialization
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
+        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
         model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
     """

@@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  seed: int = 2**10 - 1,
                  shard_param: bool = False,
                  default_dtype: Optional[torch.dtype] = None,
+                 bf16: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):

         super().__init__(default_dtype=default_dtype)

@@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.seed = seed
+        self.bf16 = bf16
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)

         self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

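A minimal sketch of passing the new flag, assuming the import paths and the constructor arguments not shown in these hunks (target_device, shard_strategy) behave as in the existing legacy-zero examples:

    import torch
    from colossalai.zero.legacy import ZeroInitContext                   # assumed import path
    from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # assumed concrete strategy

    # With bf16=True, parameters created inside the context are materialized
    # directly in torch.bfloat16 instead of torch.float16 (see half_fn below).
    with ZeroInitContext(target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         bf16=True):
        model = torch.nn.Linear(1024, 1024)
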
@@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         NOTE() The module may be passed to this function multiple times.
         """
         self.top_module = module
+        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16

         def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t
+            return t.to(half_dtype) if t.is_floating_point() else t

         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice

@@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
         # We must cast them
+        cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
-            buffer.data = cast_tensor_to_fp16(buffer.data)
+            buffer.data = cast_fn(buffer.data)


 class ZeroContextMgr(metaclass=SingletonMeta):

@@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
     if isinstance(tensor, StatefulTensor):
         tensor = tensor.payload

-    if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
+    if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
         return tensor.float()
     return tensor


+def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
+    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
+        return tensor.bfloat16()
+    return tensor
+
+
 def apply_to_tensors(x: Any, fn: Callable):
     if torch.is_tensor(x):
         return fn(x)

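A self-contained illustration of the round trip these helpers implement, using plain tensors (the StatefulTensor unwrapping above is omitted):

    import torch

    def to_bf16(t: torch.Tensor) -> torch.Tensor:
        # mirrors cast_tensor_to_bf16 for plain tensors: only fp32 floats are narrowed
        if torch.is_floating_point(t) and t.dtype is torch.float32:
            return t.bfloat16()
        return t

    def to_fp32(t: torch.Tensor) -> torch.Tensor:
        # mirrors the widened cast_tensor_to_fp32: fp16 and bf16 both go back to fp32
        if torch.is_floating_point(t) and t.dtype in (torch.float16, torch.bfloat16):
            return t.float()
        return t

    x = torch.randn(4)                                  # fp32
    assert to_bf16(x).dtype is torch.bfloat16
    assert to_fp32(to_bf16(x)).dtype is torch.float32   # round trip restores fp32
    idx = torch.arange(4)                               # integer tensors pass through untouched
    assert to_bf16(idx).dtype is torch.int64
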
@@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc

 from ._utils import (
     cast_float_arguments,
+    cast_tensor_to_bf16,
     cast_tensor_to_fp16,
     cast_tensor_to_fp32,
     chunk_and_pad,

@@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
             In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
             We find that PyTorch's optimizers don't support mixed precision,
             so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
+        bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
     """

     def __init__(self,

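The docstring above points at the usual master-weight workaround for half-precision training: the optimizer keeps an fp32 copy of each parameter and refreshes the half-precision working copy after every step (the payload_copy(p.to(half_dtype)) change further down does this for the sharded tensors). A standalone sketch of the pattern, independent of ShardedOptimizerV2:

    import torch

    # fp32 "master" weight owned by the optimizer; bf16 working copy used by the model
    master = torch.nn.Parameter(torch.randn(4, 4))
    working = master.detach().to(torch.bfloat16)

    grad_bf16 = torch.randn(4, 4, dtype=torch.bfloat16)   # gradient arrives in bf16
    master.grad = grad_bf16.float()                        # update happens in fp32
    torch.optim.SGD([master], lr=0.1).step()

    working.copy_(master.detach().to(torch.bfloat16))      # refresh the working copy
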
@@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
+                 bf16: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
+        self.bf16 = bf16

         # We force users to use ZeroInitContext
         for submodule in module.modules():

@@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations(*args)
-        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
+        args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         self._post_forward_operations()
         return outputs

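A rough sketch of what the forward-path change does, with a simplified stand-in for cast_float_arguments (the real helper also recurses into kwargs and nested containers):

    import torch
    from typing import Any, Tuple

    def cast_floats(cast_fn, *args: Any) -> Tuple[Any, ...]:
        # simplified stand-in: cast only top-level floating-point tensors
        return tuple(cast_fn(a) if torch.is_tensor(a) and a.is_floating_point() else a for a in args)

    bf16 = True
    cast_fn = (lambda t: t.bfloat16()) if bf16 else (lambda t: t.half())
    inputs = cast_floats(cast_fn, torch.randn(2, 3), torch.arange(3))
    assert inputs[0].dtype is torch.bfloat16 and inputs[1].dtype is torch.int64
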
@@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         super().__init__(optimizer)
         self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
+        self.bf16 = sharded_model.bf16

         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
         assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'

@@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
         self._verbose = verbose
+        self._grad_prepared: bool = False    # this should be set to true when _prepare_grads() and reset to false when backward

         # Store fp32 param shards
         self._register_master_weight()

@@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._zero_grad()

     def backward(self, loss: Tensor) -> None:
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
+        if not self.bf16:
+            loss = self.loss_scale * loss
+            self.optim_state = OptimState.SCALED
+        self._grad_prepared = False
         self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:

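The reason the loss is only scaled on the fp16 path: bf16 keeps fp32's 8 exponent bits (trading mantissa precision for range), so gradients rarely under- or overflow and neither loss scaling nor the overflow check in step() is needed. A quick check of the representable ranges:

    import torch

    print(torch.finfo(torch.float16).max)    # 65504.0  -> needs loss scaling
    print(torch.finfo(torch.bfloat16).max)   # ~3.39e38 -> same exponent range as fp32
    print(torch.finfo(torch.float32).max)    # ~3.40e38
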
@@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        self.optim_state = OptimState.SCALED
+        if not self.bf16:
+            self.optim_state = OptimState.SCALED
+        self._grad_prepared = False
         self.model.backward_by_grad(tensor, grad)

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        if self.optim_state == OptimState.SCALED:
-            self._prepare_grads()
+        self._prepare_grads()
+        if not self.bf16 and self.optim_state == OptimState.SCALED:
             self._unscale_grads()
         return super().clip_grad_norm(model, max_norm)

     def step(self, *args, **kwargs):

+        self._prepare_grads()
         # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._prepare_grads()
+        if not self.bf16 and self.optim_state == OptimState.SCALED:
             self._unscale_grads()

         self._maybe_move_fp32_shards()
-        found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
+        if not self.bf16:
+            found_inf = self._check_overflow()
+            self.grad_scaler.update(found_inf)

-        if found_inf:
-            self._logger.warning('found inf during ShardedOptimV2 step')
-            self._zero_grad(recover_data=True)
-            return
+            if found_inf:
+                self._logger.warning('found inf during ShardedOptimV2 step')
+                self._zero_grad(recover_data=True)
+                return

         self._point_param_fp16_to_master_param()

@@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     state[k] = v.cuda()

     def _prepare_grads(self):
+        if self._grad_prepared:
+            return
         for group in self.optim.param_groups:
             for p in group['params']:
                 if p.colo_attr.saved_grad.is_null():

@@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.grad = p.colo_attr.grad_payload
                 # Set p.data to empty tensor, in case of memory leaking
                 p.colo_attr.set_data_none()
+        self._grad_prepared = True

     def _point_param_fp16_to_master_param(self):
         # assign master param pointers to p.data.

@@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))

                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
+                half_dtype = torch.bfloat16 if self.bf16 else torch.float16
+                p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
                 p.colo_attr.set_data_none()

             if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: