[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit tests for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support FSDP and fp8 communication

* fixed some minor bugs observed in unit tests

* add all_gather_into_tensor_flat_fp8 (see the sketch below)

* skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators
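
As a rough illustration of the `all_gather_into_tensor_flat_fp8` operator mentioned above: quantize a flat local shard to FP8 with a per-shard scale, gather the payloads and scales, and dequantize on the receiver. The sketch below is conceptual only; the helper name, signature, and scaling scheme are assumptions for illustration, not the `colossalai.quantization.fp8` implementation. It assumes torch >= 2.2 (the same version gate used by the new unit tests) for `torch.float8_e4m3fn` and `dist.all_gather_into_tensor`.

```python
# Illustrative sketch only: the real all_gather_into_tensor_flat_fp8 lives in
# colossalai.quantization.fp8 and may differ in signature and scaling.
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def all_gather_flat_fp8_sketch(output: torch.Tensor, shard: torch.Tensor, group=None) -> None:
    """All-gather a flat 1-D shard in FP8, then dequantize into the flat `output` buffer."""
    world_size = dist.get_world_size(group)

    # Per-shard scale so the values fit FP8's narrow dynamic range.
    scale = shard.abs().max().clamp(min=1e-12).reshape(1) / FP8_E4M3_MAX
    fp8_shard = (shard / scale).to(torch.float8_e4m3fn)

    # NCCL has no FP8 dtype, so ship the payload as raw bytes plus the per-rank scales.
    gathered = torch.empty(world_size * shard.numel(), dtype=torch.uint8, device=shard.device)
    scales = torch.empty(world_size, dtype=scale.dtype, device=scale.device)
    dist.all_gather_into_tensor(gathered, fp8_shard.view(torch.uint8), group=group)
    dist.all_gather_into_tensor(scales, scale, group=group)

    # Dequantize each rank's chunk back into the full-precision output buffer.
    for rank, chunk in enumerate(gathered.view(torch.float8_e4m3fn).chunk(world_size)):
        output[rank * shard.numel() : (rank + 1) * shard.numel()] = chunk.to(output.dtype) * scales[rank]
```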

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hanks authored 2024-08-08 15:55:01 +08:00, committed by GitHub
parent 7739629b9d
commit b480eec738
14 changed files with 602 additions and 14 deletions


@@ -364,6 +364,7 @@ class GeminiPlugin(DPPluginBase):
         enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
         verbose: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
@@ -395,6 +396,7 @@ class GeminiPlugin(DPPluginBase):
             master_weights=master_weights,
             max_prefetch=max_prefetch,
             enable_async_reduce=enable_async_reduce,
+            fp8_communication=fp8_communication,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
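
With the flag plumbed into the Gemini config, FP8 communication is enabled from the plugin constructor. A minimal usage sketch following ColossalAI's standard Booster flow (the model and optimizer are placeholders, and the exact launch call may vary by version):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # assumes the script is started with torchrun

model = torch.nn.Linear(1024, 1024)         # placeholder model
optimizer = HybridAdam(model.parameters())  # placeholder optimizer

# fp8_communication=True is forwarded into the Gemini config (see the hunk above).
plugin = GeminiPlugin(precision="bf16", fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```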


@@ -177,6 +177,7 @@ class TorchDDPPlugin(DPPluginBase):
         check_reduction: bool = False,
         gradient_as_bucket_view: bool = False,
         static_graph: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         self.ddp_kwargs = dict(
@@ -187,6 +188,7 @@ class TorchDDPPlugin(DPPluginBase):
             gradient_as_bucket_view=gradient_as_bucket_view,
             static_graph=static_graph,
         )
+        self.fp8_communication = fp8_communication
 
     def support_no_sync(self) -> bool:
         return True
@@ -226,6 +228,11 @@ class TorchDDPPlugin(DPPluginBase):
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = OptimizerWrapper(optimizer)
 
+        if self.fp8_communication:
+            from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
+
+            model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def control_checkpoint_io(self) -> bool:
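
The hook registered here follows PyTorch's DDP communication-hook interface, `hook(state, bucket) -> torch.futures.Future[torch.Tensor]`. The real `fp8_compress_ddp_grad_comm_hook_async` lives in `colossalai.quantization.fp8`; the sketch below only illustrates the general idea (quantize the bucket to FP8, exchange it, dequantize and average locally), with the name, per-bucket scaling, and synchronous collectives chosen for brevity rather than taken from that implementation. Assumes torch >= 2.2.

```python
# Simplified illustration of an FP8 gradient-compression comm hook; the real
# fp8_compress_ddp_grad_comm_hook_async is asynchronous and more careful
# about scaling and overlap.
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn


def fp8_grad_comm_hook_sketch(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    grad = bucket.buffer()
    world_size = dist.get_world_size()

    # Quantize the whole bucket with a single scale.
    scale = grad.abs().max().clamp(min=1e-12).reshape(1) / FP8_E4M3_MAX
    fp8_grad = (grad / scale).to(torch.float8_e4m3fn)

    # NCCL cannot all-reduce FP8, so all-gather the FP8 payload (viewed as
    # bytes) together with each rank's scale, then reduce locally.
    payloads = [torch.empty_like(fp8_grad.view(torch.uint8)) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, fp8_grad.view(torch.uint8))
    dist.all_gather(scales, scale)

    reduced = torch.zeros_like(grad)
    for payload, s in zip(payloads, scales):
        reduced += payload.view(torch.float8_e4m3fn).to(grad.dtype) * s
    reduced /= world_size

    # DDP expects a Future that resolves to the reduced, bucket-shaped tensor.
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(reduced)
    return fut
```

Users do not register this by hand: passing `TorchDDPPlugin(fp8_communication=True)` to the Booster makes the plugin register the real hook during `configure`, as the hunk above shows.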


@@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
             ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
             param_init_fn: Optional[Callable[[nn.Module], None]] = None,
             sync_module_states: bool = False,
+            fp8_communication: bool = False,
         ):
             super().__init__()
             self.fsdp_kwargs = dict(
@@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
                 param_init_fn=param_init_fn,
                 sync_module_states=sync_module_states,
             )
+            self.fp8_communication = fp8_communication
 
     else:
         raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -347,6 +349,19 @@ class TorchFSDPPlugin(DPPluginBase):
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
+        if self.fp8_communication:
+            from colossalai.quantization.utils import patch_fsdp_params_comm_hook
+
+            patch_fsdp_params_comm_hook()
+
+            from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
+
+            fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
+
+            from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
+
+            fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
+
         if optimizer is not None:
             if len(optimizer.param_groups) > 1:
                 warnings.warn(
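
For FSDP, enabling the flag makes the plugin patch FSDP's params comm hook and register the FP8 grad and params hooks shown above. A minimal usage sketch following the standard Booster flow (placeholder model and optimizer; requires torch >= 1.12 for FSDP and a recent torch for the FP8 dtypes):

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

colossalai.launch_from_torch()  # assumes the script is started with torchrun

model = torch.nn.Linear(1024, 1024)                # placeholder model
optimizer = torch.optim.AdamW(model.parameters())  # placeholder optimizer

# fp8_communication=True triggers patch_fsdp_params_comm_hook() and the two
# register_*comm_hook calls shown in the hunk above.
plugin = TorchFSDPPlugin(fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```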