[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)

* support DDP for HybridPlugin/add tp+dp tests

* add docstring for HybridParallelPlugin
This commit is contained in:
Baizhou Zhang
2023-08-16 16:11:57 +08:00
committed by GitHub
parent 424629fea0
commit 6ef33f75aa
10 changed files with 199 additions and 100 deletions

View File

@@ -6,7 +6,8 @@ import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -28,7 +29,8 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelModule(ModelWrapper):
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
@@ -45,7 +47,15 @@ class HybridParallelModule(ModelWrapper):
module = module.to(dtype=torch.bfloat16).cuda()
else:
module = module.cuda() # train without AMP
# TODO(ver217): support TP+DP
if use_ddp:
# convert model to sync bn
module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
# wrap the model with PyTorch DDP
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
def sync_shared_params(self):
@@ -68,6 +78,12 @@ class HybridParallelModule(ModelWrapper):
dist.all_reduce(p.grad, group=self.dp_group)
p.grad.div_(self.dp_group.size())
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def init_pipeline_optimizer(optim: Optimizer, model: Module):
params = set(model.parameters())
@@ -140,29 +156,81 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class HybridParallelPlugin(PipelinePluginBase):
"""
Plugin for Hybrid Parallel Training.
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True.
bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
"""
def __init__(self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers=True,
bucket_cap_mb=25,
find_unused_parameters=False,
check_reduction=False,
gradient_as_bucket_view=False,
static_graph=False) -> None:
def __init__(
self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
) -> None:
super().__init__()
assert dist.get_world_size() % (
tp_size * pp_size
@@ -208,6 +276,13 @@ class HybridParallelPlugin(PipelinePluginBase):
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
self.max_norm = max_norm
@property
@@ -241,7 +316,9 @@ class HybridParallelPlugin(PipelinePluginBase):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']: