[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [test] update low level zero test

* [test] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
Author: Hongxin Liu
Date:   2023-06-05 15:58:31 +08:00 (committed by GitHub)
Parent: 07cb21142f
Commit: ae02d4e4f7

27 changed files with 738 additions and 525 deletions
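
The diff below (one of the 27 changed files) threads a mixed_precision dtype through ZeroDDP and its GeminiDDP subclass. A minimal usage sketch of the new knob, assuming the import path and the device argument of that era's GeminiDDP constructor (both may differ across releases):

import torch
import torch.nn as nn
from colossalai.zero import GeminiDDP    # assumed import path

model = nn.Linear(1024, 1024)
# mixed_precision is the knob added in this commit: torch.bfloat16 selects
# bf16 training; the default torch.float16 preserves the old fp16 behavior.
model = GeminiDDP(model,
                  device=torch.device('cuda'),    # assumed required argument
                  mixed_precision=torch.bfloat16)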


@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
         strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
             Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
         scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
+        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
     """

     def __init__(self,
@@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
-                 scatter_after_inference: bool = True) -> None:
+                 scatter_after_inference: bool = True,
+                 mixed_precision: torch.dtype = torch.float16) -> None:
+        assert mixed_precision in (torch.float16, torch.bfloat16)
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
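
The new assertion fails fast on unsupported dtypes. A standalone illustration of the guard (the helper name is hypothetical):

import torch

def check_mixed_precision(mixed_precision: torch.dtype) -> None:
    # Same guard as the constructor above: only the two half-precision
    # dtypes are accepted; torch.float32 etc. raise AssertionError.
    assert mixed_precision in (torch.float16, torch.bfloat16)

check_mixed_precision(torch.bfloat16)    # passes
# check_mixed_precision(torch.float32)   # AssertionError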
@@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
+        self.mixed_precision = mixed_precision
         self._logger = get_dist_logger()
@@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
         assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
         ), "You should run a completed iteration as your warmup iter"
-        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
+        args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
         self.module.zero_grad(set_to_none=True)
         if not grad_flag:
             outputs = self._inference_forward(*args, **kwargs)
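
_cast_float is the library's recursive casting helper; the hunk above only swaps its target dtype from the hard-coded torch.half to self.mixed_precision. A sketch of the behavior relied on here (the name _cast_float_sketch and the recursion details are assumptions, not the library's implementation):

import torch
from typing import Any

def _cast_float_sketch(obj: Any, dtype: torch.dtype) -> Any:
    # Cast floating-point tensors to the target dtype; recurse into
    # containers; leave integer tensors and non-tensors untouched.
    if isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        return obj.to(dtype)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_cast_float_sketch(x, dtype) for x in obj)
    if isinstance(obj, dict):
        return {k: _cast_float_sketch(v, dtype) for k, v in obj.items()}
    return obj

args = _cast_float_sketch((torch.randn(2), 3), torch.bfloat16)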
@@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):
             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                 continue
             # create a fp32 parameter
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             # create a fp16 parameter
-            p.data = p.data.half()
+            p.data = p.data.to(self.mixed_precision)
             # register the fp16 parameter and fp32 parameter in the chunk manager
             dp_world_size = p.process_group.dp_world_size()
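
The hunk above is the core of the dtype switch: an fp32 master copy is kept for optimizer updates while the working parameter lives in the configured half-precision dtype. The pattern in isolation (illustrative only; the real code registers both tensors with the chunk manager):

import torch
import torch.nn as nn

p = nn.Parameter(torch.randn(4, 4))
fp32_master = p.data.float()          # full-precision copy for the optimizer
p.data = p.data.to(torch.bfloat16)    # working copy used in forward/backward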
@@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
                 buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
-                buffer.data = buffer.half()
+                buffer.data = buffer.to(self.mixed_precision)

     def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
         """Convert parameter to ColoParameter in-place.
@@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
                  memstats: Optional[MemStats] = None,
+                 mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
         """
         A torch.Module wrapper using ZeRO-DP and Gemini.
@@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
                                           strict_ddp_flag=strict_ddp_mode,
                                           verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
-                         scatter_after_inference)
+        super().__init__(module,
+                         gemini_manager,
+                         pin_memory,
+                         force_outputs_fp32,
+                         strict_ddp_mode,
+                         scatter_after_inference,
+                         mixed_precision=mixed_precision)
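
Background on why bf16 is worth plumbing through (not part of the diff): bf16 trades mantissa bits for fp32's exponent range, so it is far less prone to the overflow that forces loss scaling in fp16 training:

import torch

print(torch.finfo(torch.float16).max)    # 65504.0 -- overflows easily
print(torch.finfo(torch.bfloat16).max)   # ~3.39e38 -- same range as fp32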