mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)
* [shardformer] implement policy for all GPT-J models and test * [shardformer] support interleaved pipeline parallel for bert finetune * [shardformer] shardformer support falcon (#4883) * [shardformer]: fix interleaved pipeline for bert model (#5048) * [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093) * Add Mistral support for Shardformer (#5103) * [shardformer] add tests to mistral (#5105) --------- Co-authored-by: Pengtai Xu <henryxu880@gmail.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: eric8607242 <e0928021388@gmail.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
|
||||
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
@@ -911,6 +911,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
|
||||
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
|
||||
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
|
||||
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
|
||||
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -946,6 +948,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
custom_policy: Policy = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
@@ -972,17 +976,38 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
||||
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
|
||||
self.stage_manager = PipelineStageManager(
|
||||
self.pg_mesh,
|
||||
pipeline_axis=PP_AXIS,
|
||||
enable_interleave=pp_style == "interleaved",
|
||||
num_model_chunks=num_model_chunks,
|
||||
)
|
||||
|
||||
if pp_style == "interleaved":
|
||||
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||
self.schedule = InterleavedSchedule(
|
||||
stage_manager=self.stage_manager,
|
||||
num_model_chunks=num_model_chunks,
|
||||
num_microbatch=num_microbatches,
|
||||
microbatch_size=microbatch_size,
|
||||
)
|
||||
elif pp_style == "1f1b":
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
|
||||
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
|
||||
|
||||
self.shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=self.tp_group,
|
||||
pipeline_stage_manager=self.stage_manager,
|
||||
|
Reference in New Issue
Block a user