mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement communication hook for FSDP params all-gather * added unit test for fp8 operators * support fp8 communication in GeminiPlugin * update training scripts to support fsdp and fp8 communication * fixed some minor bugs observed in unit test * add all_gather_into_tensor_flat_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * add skip the test if torch < 2.2.0 * add fp8_comm flag * rebase latest fp8 operators * rebase latest fp8 operators * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -179,7 +179,7 @@ def main():
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="torch_ddp",
|
||||
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
|
||||
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"],
|
||||
help="plugin to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -215,9 +215,9 @@ def main():
|
||||
if args.plugin == "torch_ddp_fp16":
|
||||
booster_kwargs["mixed_precision"] = "fp16"
|
||||
if args.plugin.startswith("torch_ddp"):
|
||||
plugin = TorchDDPPlugin()
|
||||
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
|
||||
elif args.plugin == "gemini":
|
||||
plugin = GeminiPlugin(initial_scale=2**5)
|
||||
plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm)
|
||||
elif args.plugin == "low_level_zero":
|
||||
plugin = LowLevelZeroPlugin(initial_scale=2**5)
|
||||
elif args.plugin == "hybrid_parallel":
|
||||
@@ -235,6 +235,17 @@ def main():
|
||||
initial_scale=1,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "torch_fsdp":
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
|
||||
|
||||
from colossalai.booster.plugin import TorchFSDPPlugin
|
||||
|
||||
plugin = TorchFSDPPlugin(
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
|
||||
),
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user