Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
  * [bf16] fused adam kernel support bf16
  * [test] update fused adam kernel test
  * [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
  * [bf16] add mixed precision mixin
  * [bf16] low level zero optim support bf16
  * [test] update low level zero test
  * [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
  * [bf16] gemini support bf16
  * [test] update gemini bf16 test
  * [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
  * [zero] init context support bf16
  * [zero] legacy zero support bf16
  * [test] add zero bf16 test
  * [doc] add bf16 related docstring for legacy zero
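For readers of this commit, a minimal sketch of how the new option is selected through the booster plugins (assuming the usual Booster workflow with a torchrun-launched process group; the toy nn.Linear model, the HybridAdam optimizer and the hyperparameters are placeholders, not part of this change):

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    # assumes the script was started with torchrun so the distributed env vars are set,
    # and that a CUDA device is available for Gemini
    colossalai.launch_from_torch(config={})

    model = nn.Linear(16, 16)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    # 'bf16' is now accepted alongside 'fp16'
    plugin = GeminiPlugin(precision='bf16')
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)

LowLevelZeroPlugin gains the same keyword; a matching sketch follows its hunks below.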
colossalai/booster/plugin/gemini_plugin.py
@@ -23,6 +23,9 @@ from .dp_plugin_base import DPPluginBase
 
 __all__ = ['GeminiPlugin']
 
+SUPPORTED_PRECISION = ['fp16', 'bf16']
+PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+
 
 class GeminiCheckpointIO(GeneralCheckpointIO):
 
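The two module-level constants added above implement a small validate-then-map pattern: the user-facing precision string is checked against SUPPORTED_PRECISION and then translated into a torch dtype. A self-contained illustration (resolve_precision is a hypothetical helper for demonstration, not part of the plugin):

    import torch

    SUPPORTED_PRECISION = ['fp16', 'bf16']
    PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}

    def resolve_precision(precision: str) -> torch.dtype:
        # same check the plugin performs in __init__ before building its config
        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
        return PRECISION_STR_TO_DTYPE[precision]

    print(resolve_precision('bf16'))   # torch.bfloat16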
@@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
     Args:
         device (torch.device): device to place the model.
         placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -203,6 +207,7 @@ class GeminiPlugin(DPPluginBase):
         self,
         device: Optional[torch.device] = None,
         placement_policy: str = "cpu",
+        precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
@@ -223,6 +228,7 @@ class GeminiPlugin(DPPluginBase):
         verbose: bool = False,
     ) -> None:
         super().__init__()
+        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
             device=(device or get_current_device()),
             placement_policy=placement_policy,
@@ -233,6 +239,7 @@ class GeminiPlugin(DPPluginBase):
             hidden_dim=hidden_dim,
             min_chunk_size_mb=min_chunk_size_mb,
             memstats=memstats,
+            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
         self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
         self.optim_kwargs = dict(initial_scale=initial_scale,
@@ -253,7 +260,7 @@ class GeminiPlugin(DPPluginBase):
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
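With supported_precisions now returning the module-level list, callers can probe a plugin instance before choosing a precision. A short sketch (assumes a process group has already been initialized, e.g. via colossalai.launch_from_torch as in the earlier snippet, since the plugin constructor runs under a distributed context):

    from colossalai.booster.plugin import GeminiPlugin

    plugin = GeminiPlugin(precision='bf16')
    print(plugin.supported_precisions())   # ['fp16', 'bf16']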
colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,4 +1,5 @@
 import warnings
+from functools import partial
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -20,12 +21,15 @@ from .torch_ddp_plugin import TorchDDPCheckpointIO
 __all__ = ['LowLevelZeroPlugin']
 
 
-def _convert_to_fp16(x):
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
-        return x.half()
+        return x.to(dtype)
     return x
 
 
+SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
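The renamed helper is now dtype-agnostic: it casts only floating-point tensors and leaves everything else untouched, which matters once it is mapped over arbitrary forward() inputs. A self-contained demonstration (the batch dict below is a made-up example):

    import torch
    from torch.utils._pytree import tree_map

    def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    # integer tensors such as token ids pass through unchanged
    batch = {'input_ids': torch.arange(4), 'embeddings': torch.randn(4, 8)}
    batch = tree_map(lambda t: _convert_floating_point(t, torch.bfloat16), batch)
    print(batch['input_ids'].dtype, batch['embeddings'].dtype)   # torch.int64 torch.bfloat16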
@@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper):
 
     def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
         super().__init__(module)
-        self.convert_inputs = (precision == 'fp16')
-        module = zero_model_wrapper(module, zero_stage=stage)
+        self.dtype = None
         if precision == 'fp16':
-            module = module.half()
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        module = zero_model_wrapper(module, zero_stage=stage)
+        if self.dtype is not None:
+            module = module.to(self.dtype)
         module = module.to(get_current_device())
         self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
 
     def forward(self, *args, **kwargs):
-        if self.convert_inputs:
-            args = tree_map(_convert_to_fp16, args)
-            kwargs = tree_map(_convert_to_fp16, kwargs)
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
 
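The __init__/forward changes above boil down to: resolve the target dtype once, cast the module to it, and keep a convert_fn that casts floating-point inputs on every forward call. A minimal stand-in that isolates just this precision handling (it skips the ZeRO wrapping and device placement that the real LowLevelZeroModel does, and repeats the helper from the previous snippet for self-containment):

    from functools import partial

    import torch
    import torch.nn as nn
    from torch.utils._pytree import tree_map

    def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    class CastingWrapper(nn.Module):
        # hypothetical simplification of LowLevelZeroModel's precision logic
        def __init__(self, module: nn.Module, precision: str) -> None:
            super().__init__()
            self.dtype = None
            if precision == 'fp16':
                self.dtype = torch.float16
            elif precision == 'bf16':
                self.dtype = torch.bfloat16
            if self.dtype is not None:
                module = module.to(self.dtype)
            self.module = module
            self.convert_fn = None
            if self.dtype is not None:
                self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)

        def forward(self, *args, **kwargs):
            if self.convert_fn is not None:
                args = tree_map(self.convert_fn, args)
                kwargs = tree_map(self.convert_fn, kwargs)
            return self.module(*args, **kwargs)

    wrapped = CastingWrapper(nn.Linear(8, 8), precision='bf16')
    print(wrapped(torch.randn(2, 8)).dtype)   # torch.bfloat16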
@@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):
 
     Args:
         strage (int, optional): ZeRO stage. Defaults to 1.
-        precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
+        precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@@ -149,7 +160,7 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
-        assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
+        assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
 
         self.stage = stage
         self.precision = precision
@@ -175,7 +186,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'fp32']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
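Analogously to the Gemini sketch earlier, the low-level ZeRO plugin now accepts 'bf16' as well; a sketch under the same assumptions (torchrun launch, placeholder model and optimizer):

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch(config={})

    model = nn.Linear(16, 16)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    # stage 1 or 2; precision may now be 'fp16', 'bf16' or 'fp32'
    plugin = LowLevelZeroPlugin(stage=2, precision='bf16')
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)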