[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark

Author: botbw
Date: 2024-09-10 17:30:53 +08:00
Committed by: GitHub
Parent commit: 8fd25d6e09
Commit: c54c4fcd15
21 changed files with 907 additions and 99 deletions

@@ -105,7 +105,7 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
-    parser.add_argument("--use_fp8", action="store_true")
+    parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
     parser.add_argument("--overlap_allgather", action="store_true")
     parser.add_argument(
         "--sp_mode",
@@ -151,6 +151,7 @@ def main():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -164,6 +165,7 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -224,6 +226,7 @@ def main():
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
@@ -241,6 +244,7 @@ def main():
             precision="bf16",
             overlap_p2p=args.overlap,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
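
The remaining hunks all make the same change: each plugin branch now forwards `fp8_communication=args.use_fp8_comm` alongside the existing `use_fp8=args.use_fp8`, matching the "pass use_fp8_comm flag to all plugins" commit above. A minimal sketch of the resulting wiring for the Gemini branch, assuming the standard ColossalAI import path and that `GeminiPlugin` accepts the two keyword arguments exactly as the hunks show (all other constructor arguments omitted):

```python
import argparse

from colossalai.booster.plugin import GeminiPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
args = parser.parse_args()

# The two flags are independent: use_fp8 turns on fp8 linear layers,
# fp8_communication turns on fp8 during collective communication.
plugin = GeminiPlugin(
    use_fp8=args.use_fp8,
    fp8_communication=args.use_fp8_comm,
)
```

Passing both flags enables fp8 for compute and communication at once; every branch touched in the diff (the two Gemini branches, the hybrid branch with `**hybrid_kwargs`, and `3d_cpu`) receives the same pair.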