Merge branch 'main' into dev/zero_bubble

This commit is contained in:
duanjunwen
2024-11-01 03:10:53 +00:00
60 changed files with 1690 additions and 834 deletions

View File

@@ -163,6 +163,8 @@ def main():
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -177,6 +179,8 @@ def main():
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp":
if use_empty_init:
@@ -188,6 +192,7 @@ def main():
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -209,6 +214,7 @@ def main():
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -219,6 +225,7 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":