Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
  * [bf16] fused adam kernel support bf16
  * [test] update fused adam kernel test
  * [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
  * [bf16] add mixed precision mixin
  * [bf16] low level zero optim support bf16
  * [test] update low level zero test
  * [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
  * [bf16] gemini support bf16
  * [test] update gemini bf16 test
  * [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
  * [zero] init context support bf16
  * [zero] legacy zero support bf16
  * [test] add zero bf16 test
  * [doc] add bf16 related docstring for legacy zero
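Taken together, these commits let a user request bf16 end to end through a plugin. A minimal usage sketch, assuming the Booster API of this release: the precision='bf16' argument is what this PR adds, while the toy model and the launch/distributed setup (elided here) are illustrative placeholders.

    import torch.nn as nn
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    # Request bf16 instead of the previously hard-coded fp16.
    plugin = GeminiPlugin(precision='bf16')
    booster = Booster(plugin=plugin)

    model = nn.Linear(16, 16)                   # placeholder model
    optimizer = HybridAdam(model.parameters())  # HybridAdam gained bf16 support in #3860
    model, optimizer, *_ = booster.boost(model, optimizer)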
@@ -23,6 +23,9 @@ from .dp_plugin_base import DPPluginBase
 
 __all__ = ['GeminiPlugin']
 
+SUPPORTED_PRECISION = ['fp16', 'bf16']
+PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+
 
 class GeminiCheckpointIO(GeneralCheckpointIO):
 
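These two module-level constants drive both the constructor's validation and the dtype that is later handed to Gemini. A standalone sketch of the validate-then-map pattern, with resolve_precision as a hypothetical helper name:

    import torch

    SUPPORTED_PRECISION = ['fp16', 'bf16']
    PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}

    def resolve_precision(precision: str) -> torch.dtype:
        # Reject anything outside the supported list, as the plugin constructor does.
        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
        return PRECISION_STR_TO_DTYPE[precision]

    assert resolve_precision('bf16') is torch.bfloat16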
@@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
     Args:
         device (torch.device): device to place the model.
         placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        precision (str, optional): precision. Supports 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
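The documented options combine freely with the new one; a hedged construction example using only parameters listed in this docstring (the values are arbitrary):

    plugin = GeminiPlugin(
        placement_policy='auto',
        precision='bf16',
        pin_memory=True,
    )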
@@ -203,6 +207,7 @@ class GeminiPlugin(DPPluginBase):
         self,
         device: Optional[torch.device] = None,
         placement_policy: str = "cpu",
+        precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
@@ -223,6 +228,7 @@ class GeminiPlugin(DPPluginBase):
         verbose: bool = False,
     ) -> None:
         super().__init__()
+        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
             device=(device or get_current_device()),
             placement_policy=placement_policy,
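The new assertion fails fast on any string outside SUPPORTED_PRECISION; note that 'fp32' is absent from the list in this diff, so requesting it raises immediately (hypothetical call):

    try:
        GeminiPlugin(precision='fp32')
    except AssertionError as e:
        print(e)  # precision fp32 is not supported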
@@ -233,6 +239,7 @@ class GeminiPlugin(DPPluginBase):
             hidden_dim=hidden_dim,
             min_chunk_size_mb=min_chunk_size_mb,
             memstats=memstats,
+            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
         self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
         self.optim_kwargs = dict(initial_scale=initial_scale,
@@ -253,7 +260,7 @@ class GeminiPlugin(DPPluginBase):
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
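Since supported_precisions() now reports both dtypes, callers can pick bf16 only where the hardware allows it. A sketch using PyTorch's torch.cuda.is_bf16_supported() (bf16 compute generally requires an Ampere-or-newer NVIDIA GPU):

    import torch
    from colossalai.booster.plugin import GeminiPlugin

    # Prefer bf16 where the GPU supports it, otherwise fall back to fp16.
    precision = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16'
    plugin = GeminiPlugin(precision=precision)
    assert precision in plugin.supported_precisions()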