[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [test] update low level zero test

* [test] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
Author: Hongxin Liu
Date: 2023-06-05 15:58:31 +08:00 (committed by GitHub)
Parent: 07cb21142f
Commit: ae02d4e4f7
27 changed files with 738 additions and 525 deletions
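Note: the simplification running through the diffs below is that bf16 keeps fp32's 8-bit exponent, so the ZeRO optimizers can skip loss scaling and the inf-recovery pass that fp16 requires. A quick standalone illustration (not part of the commit, plain PyTorch only):

import torch

# bf16 trades precision (larger eps) for fp32's dynamic range (same max exponent).
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>15}: max={info.max:.3e}, eps={info.eps:.3e}")

x = torch.tensor([1e5])
print(torch.isinf(x.half()))      # tensor([True])  -- 1e5 exceeds the fp16 max (65504)
print(torch.isinf(x.bfloat16()))  # tensor([False]) -- finite in bf16, so no loss scaling is needed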


@@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.legacy.shard_utils import BaseShardStrategy
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.legacy.sharded_param import ShardedParamV2
@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         seed (int, optional): Random seed for weight initialization
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
+        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
         model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
     """
@@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  seed: int = 2**10 - 1,
                  shard_param: bool = False,
                  default_dtype: Optional[torch.dtype] = None,
+                 bf16: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
         super().__init__(default_dtype=default_dtype)
@@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.seed = seed
+        self.bf16 = bf16
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
 
         self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
@@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         NOTE() The module may be passed to this function multiple times.
         """
         self.top_module = module
+        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
 
         def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t
+            return t.to(half_dtype) if t.is_floating_point() else t
 
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
@@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
         # We must cast them
+        cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
-            buffer.data = cast_tensor_to_fp16(buffer.data)
+            buffer.data = cast_fn(buffer.data)
 
 
 class ZeroContextMgr(metaclass=SingletonMeta):
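For orientation, a hypothetical usage sketch of the new flag. The import paths, the concrete ``TensorShardStrategy`` and the ``target_device``/``shard_strategy`` arguments are assumptions inferred from the imports and the ``ZeroContextConfig`` call above; only ``seed``, ``shard_param``, ``default_dtype``, ``bf16`` and ``model_numel_tensor`` appear in these hunks.

import torch
from colossalai.zero.legacy import ZeroInitContext                   # import path is an assumption
from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # assumed BaseShardStrategy subclass

# Parameters are created directly in torch.bfloat16 and buffers are cast with
# cast_tensor_to_bf16 instead of cast_tensor_to_fp16.
strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=strategy,
                     shard_param=True,
                     bf16=True):
    module = torch.nn.Linear(16, 16)

With the default ``bf16=False`` the context keeps its previous behaviour and produces fp16 parameters.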


@@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te
     if isinstance(tensor, StatefulTensor):
         tensor = tensor.payload
-    if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
+    if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
         return tensor.float()
     return tensor
 
 
+def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
+    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
+        return tensor.bfloat16()
+    return tensor
+
+
 def apply_to_tensors(x: Any, fn: Callable):
     if torch.is_tensor(x):
         return fn(x)
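A small usage sketch of the helpers above (plain tensors shown; ``StatefulTensor`` payloads are unwrapped inside the helpers):

import torch
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp32

x = torch.randn(4)                              # fp32
idx = torch.arange(4)                           # int64, not floating point
print(cast_tensor_to_bf16(x).dtype)             # torch.bfloat16
print(cast_tensor_to_bf16(idx).dtype)           # torch.int64 -- passed through untouched
print(cast_tensor_to_fp32(x.bfloat16()).dtype)  # torch.float32 -- the fp32 cast now also accepts bf16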


@@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc
 from ._utils import (
     cast_float_arguments,
+    cast_tensor_to_bf16,
     cast_tensor_to_fp16,
     cast_tensor_to_fp32,
     chunk_and_pad,
@@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
             In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
             We find that PyTorch's optimizers don't support mixed precision,
             so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
+        bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
     """
 
     def __init__(self,
@@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
+                 bf16: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
+        self.bf16 = bf16
 
         # We force users to use ZeroInitContext
         for submodule in module.modules():
@@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations(*args)
-        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
+        args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         self._post_forward_operations()
         return outputs
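Continuing the hypothetical sketch from the ZeroInitContext section, the wrapper is expected to be constructed with a matching flag so that fp32 inputs are cast to bf16 on entry to forward(). The positional ``module``/``shard_strategy`` arguments, the import path and an already-launched distributed context are assumptions.

import torch
from colossalai.zero.legacy import ShardedModelV2   # import path is an assumption

# `module` and `strategy` come from the ZeroInitContext sketch above; a
# colossalai distributed context is assumed to be initialized already.
model = ShardedModelV2(module, strategy, bf16=True)
out = model(torch.randn(4, 16, device='cuda'))       # fp32 input -> cast_tensor_to_bf16 -> bf16 compute
print(out.dtype)                                      # torch.bfloat16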


@@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         super().__init__(optimizer)
         self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
+        self.bf16 = sharded_model.bf16
 
         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
         assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
@@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
         self._verbose = verbose
+        self._grad_prepared: bool = False    # this should be set to true when _prepare_grads() and reset to false when backward
 
         # Store fp32 param shards
         self._register_master_weight()
@@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._zero_grad()
 
     def backward(self, loss: Tensor) -> None:
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
+        if not self.bf16:
+            loss = self.loss_scale * loss
+            self.optim_state = OptimState.SCALED
+        self._grad_prepared = False
         self.model.backward(loss)
 
     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
@@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        self.optim_state = OptimState.SCALED
+        if not self.bf16:
+            self.optim_state = OptimState.SCALED
+        self._grad_prepared = False
         self.model.backward_by_grad(tensor, grad)
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        if self.optim_state == OptimState.SCALED:
-            self._prepare_grads()
+        self._prepare_grads()
+        if not self.bf16 and self.optim_state == OptimState.SCALED:
             self._unscale_grads()
         return super().clip_grad_norm(model, max_norm)
 
     def step(self, *args, **kwargs):
+        self._prepare_grads()
         # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._prepare_grads()
+        if not self.bf16 and self.optim_state == OptimState.SCALED:
             self._unscale_grads()
 
         self._maybe_move_fp32_shards()
 
-        found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
+        if not self.bf16:
+            found_inf = self._check_overflow()
+            self.grad_scaler.update(found_inf)
 
-        if found_inf:
-            self._logger.warning('found inf during ShardedOptimV2 step')
-            self._zero_grad(recover_data=True)
-            return
+            if found_inf:
+                self._logger.warning('found inf during ShardedOptimV2 step')
+                self._zero_grad(recover_data=True)
+                return
 
         self._point_param_fp16_to_master_param()
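A toy restatement of the branches above (illustrative only, not the ShardedOptimizerV2 API): with fp16 the loss is scaled before backward and step() must unscale gradients and bail out on inf, while with bf16 both the scaling and the overflow-recovery pass are skipped.

import torch

def scaled_backward(loss: torch.Tensor, loss_scale: float, bf16: bool) -> None:
    # fp16 path scales the loss so small gradients survive the narrow exponent range
    if not bf16:
        loss = loss * loss_scale
    loss.backward()

def needs_overflow_check(bf16: bool) -> bool:
    # bf16 gradients cannot overflow where fp32 would not, so no inf recovery pass
    return not bf16

w = torch.randn(3, requires_grad=True)
scaled_backward((w ** 2).sum(), loss_scale=2.0 ** 16, bf16=True)   # no scaling applied
print(needs_overflow_check(bf16=True))                             # False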
@@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                 state[k] = v.cuda()
 
     def _prepare_grads(self):
+        if self._grad_prepared:
+            return
         for group in self.optim.param_groups:
             for p in group['params']:
                 if p.colo_attr.saved_grad.is_null():
@@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.grad = p.colo_attr.grad_payload
                 # Set p.data to empty tensor, in case of memory leaking
                 p.colo_attr.set_data_none()
+        self._grad_prepared = True
 
     def _point_param_fp16_to_master_param(self):
         # assign master param pointers to p.data.
@@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
 
         # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-        p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
+        half_dtype = torch.bfloat16 if self.bf16 else torch.float16
+        p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
         p.colo_attr.set_data_none()
 
         if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
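Finally, a generic sketch of the master-weight pattern the last hunk touches (standalone, not ColossalAI API): the optimizer keeps an fp32 master copy and writes it back into the half-precision working shard after each step, now in either fp16 or bf16.

import torch

bf16 = True
half_dtype = torch.bfloat16 if bf16 else torch.float16

master = torch.randn(8)                 # fp32 master shard owned by the optimizer
working = master.to(half_dtype)         # half-precision shard the model computes with

master.add_(1e-3)                       # stand-in for an fp32 optimizer update
working.copy_(master.to(half_dtype))    # mirrors payload_copy(p.to(half_dtype).detach())
print(working.dtype)                    # torch.bfloat16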