mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins * [example] add mixtral benchmark * [moe] refine assertion and check * [moe] fix mixtral & add more tests * [moe] consider checking dp * sp group and moe_dp_group * [mixtral] remove gate tp & add more tests * [deepseek] fix tp & sp for deepseek * [mixtral] minor fix * [deepseek] add deepseek benchmark
This commit is contained in:
@@ -105,7 +105,7 @@ def main():
|
||||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||
parser.add_argument("--no_cache", action="store_true")
|
||||
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||
parser.add_argument("--use_fp8", action="store_true")
|
||||
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
|
||||
parser.add_argument("--overlap_allgather", action="store_true")
|
||||
parser.add_argument(
|
||||
"--sp_mode",
|
||||
@@ -151,6 +151,7 @@ def main():
|
||||
max_prefetch=args.prefetch_num,
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
@@ -164,6 +165,7 @@ def main():
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
enable_flash_attention=args.xformers,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
elif args.plugin == "fsdp":
|
||||
if use_empty_init:
|
||||
@@ -224,6 +226,7 @@ def main():
|
||||
enable_metadata_cache=not args.no_cache,
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
**hybrid_kwargs,
|
||||
)
|
||||
elif args.plugin == "3d_cpu":
|
||||
@@ -241,6 +244,7 @@ def main():
|
||||
precision="bf16",
|
||||
overlap_p2p=args.overlap,
|
||||
use_fp8=args.use_fp8,
|
||||
fp8_communication=args.use_fp8_comm,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
Reference in New Issue
Block a user