moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy

This commit is contained in:
YeAnbang
2024-05-28 07:58:08 +00:00
parent 7e65b71815
commit 0b4a33548c
7 changed files with 355 additions and 91 deletions

View File

@@ -56,7 +56,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -64,7 +64,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -89,8 +89,8 @@ def train(args):
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
@@ -180,7 +180,9 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=args.tp,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -300,6 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])