[plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel
This commit is contained in:
Hongxin Liu
2024-07-18 15:33:03 +08:00
committed by GitHub
parent 73494de577
commit e86127925a
4 changed files with 42 additions and 12 deletions

View File

@@ -98,6 +98,7 @@ def main():
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--overlap_allgather", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@@ -199,9 +200,9 @@ def main():
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":