Support 4d parallel + flash attention (#5789)

* support tp + sp + pp

* remove comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Edenzzzz
2024-06-17 17:40:47 +08:00
committed by GitHub
parent 2ddf624a86
commit 8795bb2e80
5 changed files with 192 additions and 373 deletions


@@ -72,6 +72,7 @@ def main():
     parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
     parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
     parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
@@ -174,6 +175,8 @@ def main():
         tp_size=args.tp,
         pp_size=args.pp,
         zero_stage=args.zero,
+        sp_size=args.sp,
+        enable_sequence_parallelism=args.sp > 1,
         enable_fused_normalization=torch.cuda.is_available(),
         enable_flash_attention=args.xformers,
         microbatch_size=args.mbs,
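Taken together, the two hunks thread the new CLI flags into the plugin configuration. Below is a hedged, self-contained sketch of the resulting setup; every keyword argument mirrors a line visible in the hunk above, while the concrete sizes (tp=2, pp=2, sp=2) and the launch call are illustrative assumptions rather than part of this diff:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the script is started with torchrun, which sets the rank/world-size
# environment variables that colossalai reads here.
colossalai.launch_from_torch()

plugin = HybridParallelPlugin(
    tp_size=2,                          # --tp
    pp_size=2,                          # --pp
    sp_size=2,                          # --sp
    zero_stage=1,                       # --zero
    enable_sequence_parallelism=True,   # mirrors `args.sp > 1` in the diff
    enable_fused_normalization=torch.cuda.is_available(),
    enable_flash_attention=True,        # mirrors `args.xformers`
    microbatch_size=1,                  # --mbs
)
booster = Booster(plugin=plugin)
# model, optimizer, and dataloader would then be wrapped via booster.boost(...)
```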