Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)
[pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)
* test: add more p2p tests
* fix: remove send_forward_recv_forward, as a p2p op list needs to use the same group
* fix: make send and receive atomic
* feat: update P2PComm fn
* feat: add metadata cache in 1f1b
* feat: add metadata cache in interleaved pp
* feat: modify is_xx_stage fn
* revert: add _broadcast_object_list
* feat: add interleaved pp in llama policy
* feat: set NCCL_BUFFSIZE in HybridParallelPlugin
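For orientation, a hedged usage sketch of what this commit enables: constructing HybridParallelPlugin with the interleaved pipeline schedule. The pp_style, num_model_chunks, and zero_stage constraints come from the asserts in the diff below; the Booster wrapper and the tp_size argument are assumed from ColossalAI's usual plugin API and are not part of this diff.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Sketch: interleaved pipeline parallelism, as validated in the diff below.
# pp_style must be "1f1b" or "interleaved"; num_model_chunks must be 1 for "1f1b".
plugin = HybridParallelPlugin(
    tp_size=1,               # tensor-parallel degree (assumed constructor argument)
    pp_size=4,               # pp_size > 1 triggers the NCCL_BUFFSIZE default below
    pp_style="interleaved",
    num_model_chunks=2,      # each pipeline rank holds 2 model chunks
    zero_stage=1,            # must be one of (0, 1, 2)
)
booster = Booster(plugin=plugin)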
colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,4 +1,5 @@
 import ctypes
+import os
 import random
 from contextlib import contextmanager
 from functools import partial
@@ -21,7 +22,8 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -982,6 +984,13 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
+        if self.pp_size > 1:
+            if os.getenv("NCCL_BUFFSIZE") is None:
+                logger = get_dist_logger()
+                logger.warning(
+                    "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen."
+                )
+                os.environ["NCCL_BUFFSIZE"] = "134217728"
 
         assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
         assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
         assert (