[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* Support fp8_communication in the Torch DDP gradient comm, FSDP gradient comm, and FSDP parameter comm.
* Implement a communication hook for the FSDP parameter all-gather.
* Add unit tests for the FP8 operators; skip them when torch < 2.2.0.
* Support FP8 communication in GeminiPlugin.
* Update training scripts to support FSDP and FP8 communication.
* Fix minor bugs observed in the unit tests.
* Add all_gather_into_tensor_flat_fp8.
* Add the fp8_comm flag (usage sketch below).
* Rebase onto the latest FP8 operators.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
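For context, a minimal usage sketch of the new flag with the booster API. This is not part of the commit; the `launch_from_torch()` call and the toy model/optimizer are assumptions based on the usual ColossalAI workflow, and the plugin registers the FP8 gradient comm hook as shown in the diff below.

```python
# Hedged usage sketch (not part of this commit): enable FP8 gradient communication
# via the new TorchDDPPlugin flag. Assumes a torchrun launch and one CUDA device per rank.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # initializes torch.distributed from torchrun env vars

plugin = TorchDDPPlugin(fp8_communication=True)  # flag added by this PR
booster = Booster(plugin=plugin)

model = torch.nn.Linear(128, 128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# boost() wraps the model in DDP; with fp8_communication=True the plugin registers
# the FP8 gradient compression comm hook on the wrapped module (see diff below).
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```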
@@ -177,6 +177,7 @@ class TorchDDPPlugin(DPPluginBase):
         check_reduction: bool = False,
         gradient_as_bucket_view: bool = False,
         static_graph: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         self.ddp_kwargs = dict(
@@ -187,6 +188,7 @@ class TorchDDPPlugin(DPPluginBase):
             gradient_as_bucket_view=gradient_as_bucket_view,
             static_graph=static_graph,
         )
+        self.fp8_communication = fp8_communication
 
     def support_no_sync(self) -> bool:
         return True
@@ -226,6 +228,11 @@ class TorchDDPPlugin(DPPluginBase):
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
+        if self.fp8_communication:
+            from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
+
+            model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def control_checkpoint_io(self) -> bool:
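The registration above uses PyTorch's DDP communication hook API, where a hook receives the flat gradient bucket and returns a Future with a tensor of the same shape. For illustration only, here is a minimal synchronous sketch of an FP8-compressing gradient hook; it is not the `fp8_compress_ddp_grad_comm_hook_async` implementation from `colossalai.quantization.fp8`, the scaling scheme and names are assumptions, and it needs torch >= 2.2 for the FP8 dtype (matching the version guard in the tests).

```python
# Illustrative sketch only -- not the ColossalAI implementation. The flat gradient
# bucket is cast to FP8 (e4m3) with a per-bucket scale, all-gathered as raw bytes,
# then decompressed and averaged in full precision on every rank.
import torch
import torch.distributed as dist


def fp8_compress_ddp_grad_comm_hook_sketch(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)

    grad = bucket.buffer()
    # Per-bucket scale so the largest magnitude maps near the e4m3 max (~448).
    scale = grad.abs().max().clamp(min=1e-12).reshape(1) / 448.0
    fp8_grad = (grad / scale).to(torch.float8_e4m3fn)

    # Communicate the FP8 payload as uint8 bytes plus the per-rank scales.
    byte_chunks = [torch.empty_like(fp8_grad).view(torch.uint8) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(byte_chunks, fp8_grad.view(torch.uint8), group=group)
    dist.all_gather(scales, scale, group=group)

    # Decompress each rank's contribution and average; DDP uses the result as the
    # reduced gradient for this bucket.
    out = torch.zeros_like(grad)
    for chunk, s in zip(byte_chunks, scales):
        out += chunk.view(torch.float8_e4m3fn).to(grad.dtype) * s
    out /= world_size

    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
    fut.set_result(out)
    return fut


# Registration mirrors the diff above (state=None, so process_group arrives as None):
# model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sketch)
```

The all-gather-of-bytes formulation keeps the sketch simple but is not bandwidth-optimal; the async hook added in this PR overlaps the low-precision communication with computation instead of blocking inside the hook.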