Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-23 02:20:49 +00:00

Commit: [feat] support no_tp Linear for sharderformer.llama
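The hunks below remove keyword arguments that were passed twice in the same constructor call: use_fp8 and fp8_communication for GeminiPlugin, and fp8_communication for TorchFSDPPlugin. Each hunk deletes the duplicated lines and keeps a single occurrence.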
@@ -163,8 +163,6 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
-            use_fp8=args.use_fp8,
-            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -179,8 +177,6 @@ def main():
             enable_flash_attention=args.xformers,
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
-            use_fp8=args.use_fp8,
-            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -192,7 +188,6 @@ def main():
                 ),
                 param_init_fn=empty_init(),
                 fp8_communication=args.use_fp8_comm,
-                fp8_communication=args.use_fp8_comm,
             )
         else:
             plugin = TorchFSDPPlugin(
@@ -214,7 +209,6 @@ def main():
                 cpu_offload=CPUOffload(offload_params=True),
                 param_init_fn=empty_init(),
                 fp8_communication=args.use_fp8_comm,
-                fp8_communication=args.use_fp8_comm,
             )
         else:
             plugin = TorchFSDPPlugin(
@@ -225,7 +219,6 @@ def main():
                 ),
                 cpu_offload=CPUOffload(offload_params=True),
                 fp8_communication=args.use_fp8_comm,
-                fp8_communication=args.use_fp8_comm,
             )
     elif args.plugin == "3d":
         if args.pp_style == "zbv":
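Why the duplicates matter: Python rejects a repeated keyword argument when it compiles the call, so the script would fail with a SyntaxError before any plugin is constructed. Below is a minimal, self-contained sketch; the GeminiPlugin name is taken from the diff, and the offending call is compiled from a string rather than executed so the snippet itself runs cleanly.

    # A repeated keyword argument is a compile-time SyntaxError in Python.
    src = "GeminiPlugin(use_fp8=True, fp8_communication=True, use_fp8=True)"
    try:
        compile(src, "<demo>", "eval")
    except SyntaxError as err:
        # CPython 3.9+ reports: keyword argument repeated: use_fp8
        print(err.msg)

Because the error is raised at compile time, no amount of runtime guarding in the script can catch it; removing the duplicated lines, as this commit does, is the only fix.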