mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement communication hook for FSDP params all-gather * added unit test for fp8 operators * support fp8 communication in GeminiPlugin * update training scripts to support fsdp and fp8 communication * fixed some minor bugs observed in unit test * add all_gather_into_tensor_flat_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * add skip the test if torch < 2.2.0 * add fp8_comm flag * rebase latest fp8 operators * rebase latest fp8 operators * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -364,6 +364,7 @@ class GeminiPlugin(DPPluginBase):
|
||||
enable_sequence_overlap: bool = False,
|
||||
enable_async_reduce: bool = True,
|
||||
verbose: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
|
||||
@@ -395,6 +396,7 @@ class GeminiPlugin(DPPluginBase):
|
||||
master_weights=master_weights,
|
||||
max_prefetch=max_prefetch,
|
||||
enable_async_reduce=enable_async_reduce,
|
||||
fp8_communication=fp8_communication,
|
||||
)
|
||||
self.zero_optim_config = dict(
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||
|
@@ -177,6 +177,7 @@ class TorchDDPPlugin(DPPluginBase):
|
||||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ddp_kwargs = dict(
|
||||
@@ -187,6 +188,7 @@ class TorchDDPPlugin(DPPluginBase):
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph,
|
||||
)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
def support_no_sync(self) -> bool:
|
||||
return True
|
||||
@@ -226,6 +228,11 @@ class TorchDDPPlugin(DPPluginBase):
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
optimizer = OptimizerWrapper(optimizer)
|
||||
|
||||
if self.fp8_communication:
|
||||
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
|
||||
|
||||
model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
|
@@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
|
||||
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
|
||||
sync_module_states: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.fsdp_kwargs = dict(
|
||||
@@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
param_init_fn=param_init_fn,
|
||||
sync_module_states=sync_module_states,
|
||||
)
|
||||
self.fp8_communication = fp8_communication
|
||||
|
||||
else:
|
||||
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
|
||||
@@ -347,6 +349,19 @@ class TorchFSDPPlugin(DPPluginBase):
|
||||
# wrap the model with PyTorch FSDP
|
||||
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
|
||||
|
||||
if self.fp8_communication:
|
||||
from colossalai.quantization.utils import patch_fsdp_params_comm_hook
|
||||
|
||||
patch_fsdp_params_comm_hook()
|
||||
|
||||
from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook
|
||||
|
||||
fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
||||
|
||||
from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook
|
||||
|
||||
fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
||||
|
||||
if optimizer is not None:
|
||||
if len(optimizer.param_groups) > 1:
|
||||
warnings.warn(
|
||||
|
Reference in New Issue
Block a user