[shardformer]: support GPT-J, Falcon, Mistral and add interleaved pipeline for BERT (#5088)

* [shardformer] implement policy for all GPT-J models and test

* [shardformer] support interleaved pipeline parallel for bert finetune

* [shardformer] shardformer support falcon (#4883)

* [shardformer]: fix interleaved pipeline for bert model (#5048)

* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)

* Add Mistral support for Shardformer (#5103)

* [shardformer] add tests to mistral (#5105)

---------

Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Author: Wenhao Chen
Date: 2023-11-28 16:54:42 +08:00
Committed by: GitHub
Parent: 126cf180bc
Commit: 7172459e74

39 changed files with 4007 additions and 234 deletions

colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -22,7 +22,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
@@ -911,6 +911,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+        pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
     """
 
     def __init__(
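
For context, here is a minimal sketch (not part of this diff) of how the two new arguments documented above might be passed when constructing the plugin. It assumes a distributed environment already initialized, e.g. via colossalai.launch_from_torch, with enough ranks for the chosen parallel sizes:

# Illustrative usage sketch only, not taken from this commit.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    pp_style="interleaved",  # new argument, defaults to "1f1b"
    num_model_chunks=2,      # must be > 1 when pp_style is "interleaved"
    num_microbatches=4,      # required when pp_size > 1
    zero_stage=1,            # zero stage must be 0 or 1 with pipeline parallelism
)
booster = Booster(plugin=plugin)
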
@@ -946,6 +948,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         custom_policy: Policy = None,
+        pp_style: str = "1f1b",
+        num_model_chunks: int = 1,
     ) -> None:
         super().__init__()
         assert (
@@ -972,17 +976,38 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
+            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
+            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
-            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
-            self.schedule = OneForwardOneBackwardSchedule(
-                self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+            self.stage_manager = PipelineStageManager(
+                self.pg_mesh,
+                pipeline_axis=PP_AXIS,
+                enable_interleave=pp_style == "interleaved",
+                num_model_chunks=num_model_chunks,
             )
+            if pp_style == "interleaved":
+                assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
+                self.schedule = InterleavedSchedule(
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                )
+            elif pp_style == "1f1b":
+                self.schedule = OneForwardOneBackwardSchedule(
+                    self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+                )
+            else:
+                raise NotImplementedError()
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
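
The interleaved schedule selected above gives each pipeline rank more than one model chunk, which shrinks the pipeline bubble compared with plain 1F1B at the cost of extra communication. The snippet below illustrates only the general interleaved (virtual pipeline) placement idea, chunk i landing on stage i % pp_size; it is not copied from ColossalAI's internal partitioning code, and the function name and even layer split are assumptions for this sketch:

# Illustration of round-robin chunk placement as used by interleaved pipelines
# in general; not ColossalAI's actual partitioning logic.
def chunk_to_stage(num_layers: int, pp_size: int, num_model_chunks: int):
    total_chunks = pp_size * num_model_chunks
    layers_per_chunk = num_layers // total_chunks
    assignment = {}  # chunk index -> (pipeline stage, layer indices)
    for chunk in range(total_chunks):
        stage = chunk % pp_size
        start = chunk * layers_per_chunk
        assignment[chunk] = (stage, list(range(start, start + layers_per_chunk)))
    return assignment

# A 12-layer BERT encoder with pp_size=2 and num_model_chunks=2:
# stage 0 holds chunks 0 (layers 0-2) and 2 (layers 6-8),
# stage 1 holds chunks 1 (layers 3-5) and 3 (layers 9-11).
print(chunk_to_stage(12, pp_size=2, num_model_chunks=2))
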