[feat] support no_tp Linear for shardformer.llama

This commit is contained in:
duanjunwen
2024-11-05 05:55:42 +00:00
parent 8e40087633
commit 4fc92aa77d
5 changed files with 140 additions and 42 deletions

View File

@@ -163,8 +163,6 @@ def main():
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -179,8 +177,6 @@ def main():
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp":
if use_empty_init:
@@ -192,7 +188,6 @@ def main():
),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -214,7 +209,6 @@ def main():
cpu_offload=CPUOffload(offload_params=True),
param_init_fn=empty_init(),
fp8_communication=args.use_fp8_comm,
fp8_communication=args.use_fp8_comm,
)
else:
plugin = TorchFSDPPlugin(
@@ -225,7 +219,6 @@ def main():
),
cpu_offload=CPUOffload(offload_params=True),
fp8_communication=args.use_fp8_comm,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "3d":
if args.pp_style == "zbv":