mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00

[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)

* support DDP for HybridPlugin/add tp+dp tests
* add docstring for HybridParallelPlugin

This commit is contained in:
@@ -6,7 +6,8 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
-from torch.nn import Module
+from torch.nn import Module, SyncBatchNorm
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
@@ -28,7 +29,8 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

 class HybridParallelModule(ModelWrapper):

-    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
+    def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
+                 ddp_config: dict) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group
         shardformer = ShardFormer(shard_config)
@@ -45,7 +47,15 @@ class HybridParallelModule(ModelWrapper):
             module = module.to(dtype=torch.bfloat16).cuda()
         else:
             module = module.cuda()    # train without AMP
-        # TODO(ver217): support TP+DP
+
+        if use_ddp:
+
+            # convert model to sync bn
+            module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
+
+            # wrap the model with PyTorch DDP
+            module = DDP(module, process_group=dp_group, **ddp_config)
+
         super().__init__(module)

     def sync_shared_params(self):
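For reference, the wrapping that HybridParallelModule now performs when use_ddp is true follows the standard PyTorch pattern sketched below. This is a minimal sketch rather than the plugin's actual code path: the process group here is the default world group instead of the data-parallel group taken from the plugin's process-group mesh, and the model and helper names are placeholders.

import torch
import torch.distributed as dist
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_data_parallel(module: Module, dp_group, ddp_config: dict) -> Module:
    # Replace BatchNorm layers with SyncBatchNorm synchronized over the
    # data-parallel group, then wrap with DDP so gradients are bucketed and
    # all-reduced over that group only (not over tensor-parallel ranks).
    module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
    return DDP(module, process_group=dp_group, **ddp_config)


if __name__ == "__main__":
    # Hypothetical usage under torchrun, one GPU per process.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).cuda()
    model = wrap_for_data_parallel(model, dist.group.WORLD,
                                   dict(broadcast_buffers=True, bucket_cap_mb=25))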
@@ -68,6 +78,12 @@ class HybridParallelModule(ModelWrapper):
                 dist.all_reduce(p.grad, group=self.dp_group)
                 p.grad.div_(self.dp_group.size())

+    def unwrap(self):
+        module = super().unwrap()
+        if isinstance(module, DDP):
+            module = module.module
+        return module
+

 def init_pipeline_optimizer(optim: Optimizer, model: Module):
     params = set(model.parameters())
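The manual gradient averaging shown above (an all-reduce over dp_group followed by division by the group size) covers the case where DDP is not wrapping the module, and the new unwrap strips the DDP wrapper back off. A minimal sketch of that averaging pattern, assuming the process group is already initialized; the helper name is ours, not part of the plugin:

import torch
import torch.distributed as dist


def average_gradients(module: torch.nn.Module, dp_group: dist.ProcessGroup) -> None:
    # Sum every gradient across the data-parallel ranks, then divide by the
    # number of ranks so each rank ends up with the mean gradient, which is
    # the same result DDP would produce automatically.
    if dist.get_world_size(group=dp_group) == 1:
        return
    for p in module.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dp_group)
            p.grad.div_(dist.get_world_size(group=dp_group))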
@@ -140,29 +156,81 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):


 class HybridParallelPlugin(PipelinePluginBase):
"""
|
||||
Plugin for Hybrid Parallel Training.
|
||||
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
|
||||
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
|
||||
|
||||
Example:
|
||||
>>> from colossalai.booster import Booster
|
||||
>>> from colossalai.booster.plugin import HybridParallelPlugin
|
||||
|
||||
>>> model, train_dataset, optimizer, criterion = ...
|
||||
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
|
||||
|
||||
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
|
||||
>>> booster = Booster(plugin=plugin)
|
||||
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
|
||||
|
||||
Args:
|
||||
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
|
||||
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
|
||||
precision (str, optional): Specifies the precision of parameters during training.
|
||||
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
|
||||
Defaults to 'fp16'.
|
||||
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
|
||||
When set to 0, ZeRO will not be used. Defaults to 0.
|
||||
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
|
||||
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
|
||||
Currently all the optimization methods include fused normalization, flash attention and JIT.
|
||||
Defaults to False.
|
||||
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
|
||||
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
|
||||
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
|
||||
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
|
||||
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
|
||||
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
|
||||
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
|
||||
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
|
||||
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
|
||||
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
|
||||
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
|
||||
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True.
|
||||
bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
|
||||
find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
|
||||
check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
|
||||
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
|
||||
static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
|
||||
"""
|
||||
|
||||
-    def __init__(
-        self,
-        tp_size: int,
-        pp_size: int,
-        precision: str = 'fp16',
-        zero_stage: int = 0,
-        cpu_offload: bool = False,
-        enable_all_optimization: bool = False,
-        enable_fused_normalization: bool = False,
-        enable_flash_attention: bool = False,
-        enable_jit_fused: bool = False,
-        enable_sequence_parallelism: bool = False,
-        num_microbatches: Optional[int] = None,
-        initial_scale: float = 2**16,
-        min_scale: float = 1,
-        growth_factor: float = 2,
-        backoff_factor: float = 0.5,
-        growth_interval: int = 1000,
-        hysteresis: int = 2,
-        max_scale: float = 2**32,
-        max_norm: float = 0,
-    ) -> None:
+    def __init__(self,
+                 tp_size: int,
+                 pp_size: int,
+                 precision: str = 'fp16',
+                 zero_stage: int = 0,
+                 cpu_offload: bool = False,
+                 enable_all_optimization: bool = False,
+                 enable_fused_normalization: bool = False,
+                 enable_flash_attention: bool = False,
+                 enable_jit_fused: bool = False,
+                 enable_sequence_parallelism: bool = False,
+                 num_microbatches: Optional[int] = None,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32,
+                 max_norm: float = 0,
+                 broadcast_buffers=True,
+                 bucket_cap_mb=25,
+                 find_unused_parameters=False,
+                 check_reduction=False,
+                 gradient_as_bucket_view=False,
+                 static_graph=False) -> None:
         super().__init__()
         assert dist.get_world_size() % (
             tp_size * pp_size
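The new keyword arguments are simply collected into self.ddp_config (see the next hunk) and forwarded to torch.nn.parallel.DistributedDataParallel. A hedged sketch of how a TP+DP run might be configured with this plugin; the sizes and flag values are illustrative and assume an 8-GPU job:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# With 8 ranks, tp_size=2 and pp_size=1 leave dp_size = 8 // (2 * 1) = 4,
# so DDP handles the data-parallel dimension alongside tensor parallelism.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=1,
                              precision='fp16',
                              broadcast_buffers=True,
                              bucket_cap_mb=25,
                              find_unused_parameters=False)
booster = Booster(plugin=plugin)
# The model, optimizer, criterion and dataloader would then be passed through
# booster.boost(...) exactly as in the docstring example above.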
@@ -208,6 +276,13 @@ class HybridParallelPlugin(PipelinePluginBase):
             min_scale=min_scale,
             max_scale=max_scale,
         )
+
+        self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
+                               bucket_cap_mb=bucket_cap_mb,
+                               find_unused_parameters=find_unused_parameters,
+                               check_reduction=check_reduction,
+                               gradient_as_bucket_view=gradient_as_bucket_view,
+                               static_graph=static_graph)
         self.max_norm = max_norm

     @property
@@ -241,7 +316,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         if not isinstance(model, ModelWrapper):
-            model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
+            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
+                                         self.ddp_config)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
                 if self.precision in ['fp16', 'bf16']:
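The use_ddp condition in configure restricts DDP to the pure data-parallel case; when pipeline parallelism or ZeRO is active, gradient synchronization is handled by those paths instead (the manual all-reduce in HybridParallelModule, or HybridParallelZeroOptimizer). A small illustration of the decision with hypothetical sizes:

# Hypothetical configuration: 8 ranks, tp_size=2, pp_size=1, zero_stage=0.
world_size = 8
tp_size, pp_size, zero_stage = 2, 1, 0

dp_size = world_size // (tp_size * pp_size)                   # -> 4
use_ddp = dp_size > 1 and pp_size == 1 and zero_stage == 0    # -> True

# Setting pp_size=2 or zero_stage>=1 would make use_ddp False, and gradient
# synchronization would fall back to the non-DDP paths mentioned above.
print(dp_size, use_ddp)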