Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 03:52:01 +00:00
[Feature] optimize PP overlap (#5735)
* update to fully overlap, still debugging
* improve interface
* fixed deadlock bug
* debug NaN loss
* (experimental) use one comm group for send_fw_recv_fw to fix NaN
* cleaned up interfaces; use one batch p2p for all
* clean up; removed the double p2p batch case
* p2p test passed
* improve overlap: send fwd before backward
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* tentatively use 2 p2p batches
* remove two p2p batches
* fix typos
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove pp.sh

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: root <root@notebook-c55824c0-7742-45e8-9591-c855bb77ad29-0.notebook-c55824c0-7742-45e8-9591-c855bb77ad29.colossal-ai.svc.cluster.local>
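The feature is exposed through the new overlap_p2p flag on HybridParallelPlugin and forwarded to the interleaved pipeline schedule, as the diff below shows. A minimal, hypothetical usage sketch; all plugin arguments other than overlap_p2p are illustrative placeholders, not taken from this commit:

# Hypothetical usage sketch: enabling the overlapped p2p path added in this
# commit. Plugin arguments besides overlap_p2p are illustrative placeholders.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="interleaved",   # overlap_p2p is forwarded to the interleaved schedule
    num_model_chunks=2,
    num_microbatches=4,
    overlap_p2p=True,         # new flag from this commit; defaults to True
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)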
@@ -946,7 +946,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
-
+        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
     """

     def __init__(
@@ -992,6 +992,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_metadata_cache: bool = True,
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
+        overlap_p2p: bool = True,
     ) -> None:
         super().__init__()
         assert (
@@ -1062,7 +1063,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
-            assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
+            assert (
+                self.zero_stage <= 1
+            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
@@ -1079,6 +1082,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     num_microbatch=num_microbatches,
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
+                    overlap_p2p=overlap_p2p,
                 )
             elif pp_style == "1f1b":
                 self.schedule = OneForwardOneBackwardSchedule(
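For context, the overlap itself follows the standard PyTorch batched asynchronous p2p pattern (cf. "use one batch p2p for all" in the commit message): issue every send/recv of a step as one batched call, keep the returned work handles, run computation while the transfers are in flight, and wait only when the received tensor is actually consumed. A generic, hypothetical sketch of that pattern, not the ColossalAI implementation; prev_rank/next_rank and the surrounding schedule are made up:

# Hypothetical sketch of overlapped p2p with torch.distributed: batch all
# isend/irecv ops in one call, overlap them with compute, wait before use.
import torch
import torch.distributed as dist

def send_fw_recv_fw(send_tensor, next_rank, prev_rank):
    """Send activations forward and post the matching receive in one batch."""
    recv_tensor = torch.empty_like(send_tensor)
    ops = [
        dist.P2POp(dist.isend, send_tensor, next_rank),
        dist.P2POp(dist.irecv, recv_tensor, prev_rank),
    ]
    handles = dist.batch_isend_irecv(ops)  # one batched p2p call for all ops
    return recv_tensor, handles

# Overlap: launch communication first, compute while it is in flight,
# and block only when the received activation is actually needed.
# recv_tensor, handles = send_fw_recv_fw(output, next_rank, prev_rank)
# local_result = forward_step(current_microbatch)   # overlapped compute
# for h in handles:
#     h.wait()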