[gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix
Author: flybird11111
Date: 2023-11-16 21:03:04 +08:00
Committed by: GitHub
Parent: b2ad0d9e8f
Commit: 3e02154710
10 changed files with 96 additions and 137 deletions

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
@@ -34,8 +35,7 @@ __all__ = ["GeminiPlugin"]
 SUPPORTED_PRECISION = ["fp16", "bf16"]
 PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
-DP_AXIS = 0
-TP_AXIS = 1
+ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
 
 
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
-        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
-        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        tp_size (int, optional): If 'tp_size' is greater than 1, tensor parallelism (implemented in Shardformer) is enabled, and 'tp_size' determines the size of the tensor parallel process group. Defaults to 1.
+        extra_dp_size (int, optional): If 'extra_dp_size' is greater than 1, an extra data parallel group is created to run a DDP-like strategy on top of ZeRO. Defaults to 1.
         enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
             Currently all the optimization methods include fused normalization, flash attention and JIT.
             Defaults to False.
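
For readers of this diff, a minimal usage sketch of the two arguments documented above (not part of this commit; the 4-GPU layout and the `nn.Linear` stand-in model are assumptions, and the launch call assumes the `torchrun` launcher):

```python
# Hypothetical sketch: 4 GPUs split as zero_size=2 x extra_dp_size=2 (tp_size=1).
# Setting tp_size > 1 additionally requires a model architecture supported by Shardformer.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

model = nn.Linear(1024, 1024)                      # stand-in model
optimizer = torch.optim.Adam(model.parameters())

# extra_dp_size > 1 adds the new DDP-like replica group on top of ZeRO.
plugin = GeminiPlugin(precision="bf16", tp_size=1, extra_dp_size=2)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```
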
@@ -347,8 +347,8 @@ class GeminiPlugin(DPPluginBase):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
-        enable_tensor_parallelism: bool = False,
         tp_size: int = 1,
+        extra_dp_size: int = 1,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
@@ -393,7 +393,7 @@ class GeminiPlugin(DPPluginBase):
             max_norm=max_norm,
             norm_type=norm_type,
         )
-        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_tensor_parallelism = tp_size > 1
         self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention
@@ -402,12 +402,17 @@
         self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose
-        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
-        self.dp_size = dist.get_world_size() // self.tp_size
-        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.tp_size = tp_size
+        self.extra_dp_size = extra_dp_size
+        world_size = dist.get_world_size()
+        self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
+        assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."
+        self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
+        self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
+        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
 
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             enable_tensor_parallelism=self.enable_tensor_parallelism,
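
As a side note on the new 3-axis mesh: the sketch below is an illustration only, assuming the row-major rank layout that `ProcessGroupMesh` uses, and shows how an 8-rank world with zero_size = extra_dp_size = tp_size = 2 would map onto the (ZERO_AXIS, DP_AXIS, TP_AXIS) coordinates.

```python
# Illustration of the assumed row-major rank layout behind ProcessGroupMesh(zero, dp, tp).
import numpy as np

zero_size, extra_dp_size, tp_size = 2, 2, 2
for rank in range(zero_size * extra_dp_size * tp_size):
    z, d, t = np.unravel_index(rank, (zero_size, extra_dp_size, tp_size))
    print(f"rank {rank}: zero={z}, extra_dp={d}, tp={t}")

# Ranks that share (zero, extra_dp) form a TP group, ranks that share
# (extra_dp, tp) form a ZeRO group, and ranks that share (zero, tp) form
# an extra-DP group; get_group_along_axis(AXIS) returns the group along
# one axis for fixed coordinates on the other axes.
```
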
@@ -458,7 +463,7 @@ class GeminiPlugin(DPPluginBase):
             shardformer = ShardFormer(self.shard_config)
             model, _ = shardformer.optimize(model)
-            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
+            model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
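
To make the intent of `zero_group` / `extra_dp_group` concrete, here is a conceptual sketch of the gradient flow: ZeRO-style reduction inside the zero group, followed by a DDP-like all-reduce of the shard across the extra-dp group. This is an illustration of the idea, not GeminiDDP's internal code; it assumes the gradient's element count is divisible by the zero-group size and an NCCL backend (for `ReduceOp.AVG`).

```python
import torch
import torch.distributed as dist

def reduce_gradient(grad: torch.Tensor, zero_group, extra_dp_group=None) -> torch.Tensor:
    """Conceptual sketch: average-reduce-scatter within zero_group, then
    average the resulting shard across extra_dp_group (DDP-like)."""
    zero_size = dist.get_world_size(zero_group)
    flat = grad.reshape(-1)
    shard = torch.empty(flat.numel() // zero_size, dtype=flat.dtype, device=flat.device)
    # Each rank in the zero group keeps one averaged shard of the gradient (ZeRO).
    dist.reduce_scatter_tensor(shard, flat, op=dist.ReduceOp.AVG, group=zero_group)
    if extra_dp_group is not None:
        # Replicas across the extra-dp group hold the same shard; average it across replicas.
        dist.all_reduce(shard, op=dist.ReduceOp.AVG, group=extra_dp_group)
    return shard
```

In the plugin itself, `zero_group` falls back to the default process group when ZeRO spans the whole world, and `extra_dp_group` is `None` when `extra_dp_size == 1`, so the extra all-reduce simply disappears.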