[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* add skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hanks
2024-08-08 15:55:01 +08:00
committed by GitHub
parent 7739629b9d
commit b480eec738
14 changed files with 602 additions and 14 deletions

View File

@@ -179,7 +179,7 @@ def main():
"--plugin",
type=str,
default="torch_ddp",
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"],
help="plugin to use",
)
parser.add_argument(
@@ -215,9 +215,9 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm)
elif args.plugin == "low_level_zero":
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == "hybrid_parallel":
@@ -235,6 +235,17 @@ def main():
initial_scale=1,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "torch_fsdp":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from colossalai.booster.plugin import TorchFSDPPlugin
plugin = TorchFSDPPlugin(
mixed_precision=MixedPrecision(
param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
),
fp8_communication=args.use_fp8_comm,
)
booster = Booster(plugin=plugin, **booster_kwargs)

View File

@@ -212,7 +212,7 @@ def main():
if args.plugin == "torch_ddp_fp16":
booster_kwargs["mixed_precision"] = "fp16"
if args.plugin.startswith("torch_ddp"):
plugin = TorchDDPPlugin()
plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
elif args.plugin == "gemini":
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == "low_level_zero":

View File

@@ -98,7 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -158,6 +158,7 @@ def main():
buffer_dtype=torch.float16,
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -165,7 +166,8 @@ def main():
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
)
),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp_cpu":
if use_empty_init:
@@ -177,6 +179,7 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -186,6 +189,7 @@ def main():
buffer_dtype=torch.float16,
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
@@ -200,9 +204,9 @@ def main():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@@ -293,7 +297,7 @@ def main():
with get_profile_context(
args.profile,
args.ignore_steps,
1, # avoid creating massive log files
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: