diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 3fbeebcc4..d15245523 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -22,6 +22,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
 from .pp_plugin_base import PipelinePluginBase
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 class HybridParallelModule(ModelWrapper):
 
     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
-                 ddp_config: dict) -> None:
+                 ddp_config: dict, custom_policy: Policy) -> None:
 
         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group
 
         shardformer = ShardFormer(shard_config)
-        module, self.shared_params = shardformer.optimize(module)
+        if custom_policy is not None:
+            assert isinstance(custom_policy, object)
+        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
 
         # setting process groups for shared parameters
         self.shared_param_process_groups = []
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
     """
 
     def __init__(self,
@@ -302,7 +306,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                  zero_bucket_size_in_m: int = 12,
                  cpu_offload: bool = False,
                  communication_dtype: Optional[torch.dtype] = None,
-                 overlap_communication: bool = True) -> None:
+                 overlap_communication: bool = True,
+                 custom_policy: Policy = None) -> None:
 
         super().__init__()
         assert dist.get_world_size() % (
@@ -326,6 +331,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
+        self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
             assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
@@ -405,7 +411,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
-                                         self.ddp_config)
+                                         self.ddp_config, self.custom_policy)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
                 if self.precision in ['fp16', 'bf16']:
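
Usage note (illustrative, not part of the patch above): the sketch below shows how the new custom_policy argument could be passed end to end. MyCustomPolicy, its stubbed hooks, and the exact set of abstract methods on the base Policy are assumptions for illustration and may differ across ColossalAI versions; the patch itself only threads the policy from HybridParallelPlugin through HybridParallelModule into ShardFormer.optimize.

# Usage sketch for the new custom_policy argument (illustrative only, not part of this patch).
# Assumes the Policy hooks shown (config_sanity_check / preprocess / module_policy /
# postprocess / get_held_layers / get_shared_params) match the base class in the targeted
# ColossalAI version.
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.policies.base_policy import Policy


class MyCustomPolicy(Policy):
    """Hypothetical user-defined policy; the hook bodies below are stubs for illustration."""

    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        return self.model

    def module_policy(self):
        # Return shardformer substitution rules here; an empty dict shards nothing.
        return {}

    def postprocess(self) -> nn.Module:
        return self.model

    def get_held_layers(self):
        return []

    def get_shared_params(self):
        return []


# Run under a distributed launcher (e.g. colossalai.launch_from_torch) with a world size
# divisible by tp_size * pp_size. The plugin stores the policy and hands it to
# HybridParallelModule, which forwards it to ShardFormer.optimize(module, policy=...).
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, custom_policy=MyCustomPolicy())
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, ...)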