mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy
This commit is contained in:
@@ -125,11 +125,12 @@ def train(args):
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=True if args.sp > 1 else False,
|
||||
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.batch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
@@ -194,7 +195,9 @@ def train(args):
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
tp_size=args.tp,
|
||||
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
|
||||
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
|
||||
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
|
||||
)
|
||||
# print(len(train_dataloader))
|
||||
# for batch in train_dataloader:
|
||||
@@ -321,6 +324,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
|
Reference in New Issue
Block a user