[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
* [bf16] fused adam kernel support bf16
* [test] update fused adam kernel test
* [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
* [bf16] add mixed precision mixin
* [bf16] low level zero optim support bf16
* [test] update low level zero test
* [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
* [bf16] gemini support bf16
* [test] update gemini bf16 test
* [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
* [zero] init context support bf16
* [zero] legacy zero support bf16
* [test] add zero bf16 test
* [doc] add bf16 related docstring for legacy zero
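
A minimal usage sketch of the new option (not part of this commit): only the mixed_precision keyword comes from this change; the import path and the device argument are assumptions about the surrounding GeminiDDP API and may differ between versions.

    import torch
    import torch.nn as nn
    from colossalai.zero import GeminiDDP    # import path assumed; adjust to your version

    model = nn.Linear(1024, 1024)
    # mixed_precision was previously effectively hard-coded to fp16; torch.bfloat16 is now also accepted.
    ddp_model = GeminiDDP(model, device=torch.device('cuda'), mixed_precision=torch.bfloat16)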
@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
         strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
             Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
         scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
+        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
     """

     def __init__(self,
@@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
-                 scatter_after_inference: bool = True) -> None:
+                 scatter_after_inference: bool = True,
+                 mixed_precision: torch.dtype = torch.float16) -> None:
+        assert mixed_precision in (torch.float16, torch.bfloat16)
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
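
The new assertion restricts mixed_precision to fp16 or bf16. A standalone illustration of that check (a sketch, not library code):

    import torch

    def check_mixed_precision(mixed_precision: torch.dtype) -> None:
        # Mirrors the assert added to ZeroDDP.__init__: only fp16 and bf16 are allowed.
        assert mixed_precision in (torch.float16, torch.bfloat16)

    check_mixed_precision(torch.bfloat16)    # passes
    # check_mixed_precision(torch.float32)   # would raise an AssertionError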
@@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
+        self.mixed_precision = mixed_precision

         self._logger = get_dist_logger()

@@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
         assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
         ), "You should run a completed iteration as your warmup iter"

-        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
+        args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
         self.module.zero_grad(set_to_none=True)
         if not grad_flag:
             outputs = self._inference_forward(*args, **kwargs)
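
Forward inputs are now cast to the configured dtype rather than to a hard-coded torch.half. The sketch below shows the kind of recursive cast a helper like _cast_float performs; it is illustrative and not ColossalAI's actual implementation.

    from typing import Any
    import torch

    def cast_float(obj: Any, dtype: torch.dtype) -> Any:
        # Convert floating-point tensors nested in lists/tuples/dicts to `dtype`.
        if isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
            return obj.to(dtype)
        if isinstance(obj, (list, tuple)):
            return type(obj)(cast_float(x, dtype) for x in obj)
        if isinstance(obj, dict):
            return {k: cast_float(v, dtype) for k, v in obj.items()}
        return obj

    batch = {'input': torch.randn(2, 8), 'mask': torch.ones(2, 8, dtype=torch.bool)}
    batch = cast_float(batch, torch.bfloat16)    # only the float tensor is converted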
@@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):

             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                 continue

             # create a fp32 parameter
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             # create a fp16 parameter
-            p.data = p.data.half()
+            p.data = p.data.to(self.mixed_precision)

             # register the fp16 parameter and fp32 parameter in the chunk manager
             dp_world_size = p.process_group.dp_world_size()
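
The hunk above keeps the usual mixed-precision layout, now parameterized by dtype: a float32 master copy is registered for optimizer updates, while the working copy used in forward/backward is stored in fp16 or bf16. In isolation (illustrative sketch, without the chunk manager):

    import torch

    mixed_precision = torch.bfloat16
    p = torch.nn.Parameter(torch.randn(4, 4))

    fp32_master = p.data.float()             # full-precision copy kept for optimizer updates
    p.data = p.data.to(mixed_precision)      # working copy that runs forward/backward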
@@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
                 buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
-                buffer.data = buffer.half()
+                buffer.data = buffer.to(self.mixed_precision)

     def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
         """Convert parameter to ColoParameter in-place.
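
Module buffers follow the same rule: floating-point buffers are converted to the configured dtype, while integer and boolean buffers keep their original type. A small illustration with an assumed module (the CUDA move and lazy materialization from the diff are omitted here):

    import torch
    import torch.nn as nn

    class WithBuffers(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer('running_mean', torch.zeros(8))    # floating-point buffer
            self.register_buffer('position_ids', torch.arange(8))   # integer buffer

    m = WithBuffers()
    mixed_precision = torch.bfloat16
    for buf in m.buffers():
        if torch.is_floating_point(buf):
            buf.data = buf.to(mixed_precision)    # only float buffers are cast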
@@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
                  memstats: Optional[MemStats] = None,
+                 mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
         """
         A torch.Module wrapper using ZeRO-DP and Gemini.
@@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
                                              strict_ddp_flag=strict_ddp_mode,
                                              verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
-                         scatter_after_inference)
+        super().__init__(module,
+                         gemini_manager,
+                         pin_memory,
+                         force_outputs_fp32,
+                         strict_ddp_mode,
+                         scatter_after_inference,
+                         mixed_precision=mixed_precision)