mirror of https://github.com/hpcaitech/ColossalAI.git
[gemini] gemini support extra-dp (#5043)
* support ddp
* fix
* fix
* fix fix
* support ddp
* fix
* fix
* fix fix
* simplify tests
* fix
* fix
* fix fix fix
* fix
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
@@ -34,8 +35,7 @@ __all__ = ["GeminiPlugin"]
 SUPPORTED_PRECISION = ["fp16", "bf16"]
 PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
 
-DP_AXIS = 0
-TP_AXIS = 1
+ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
 
 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
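Note: the old DP/TP axis pair becomes a three-axis layout (ZeRO, extra DP, TP). As a rough sketch of what that means for rank placement, assuming ProcessGroupMesh lays ranks out row-major with the first axis outermost (an assumption about the mesh, not something stated in this diff), a flat rank would map to mesh coordinates like this:

# Hypothetical illustration only: map a flat rank to (zero, extra_dp, tp) coordinates,
# assuming a row-major layout with ZERO_AXIS outermost and TP_AXIS innermost.
def rank_to_coords(rank: int, zero_size: int, extra_dp_size: int, tp_size: int):
    tp_rank = rank % tp_size
    extra_dp_rank = (rank // tp_size) % extra_dp_size
    zero_rank = rank // (tp_size * extra_dp_size)
    return zero_rank, extra_dp_rank, tp_rank

# Example with 8 ranks, zero_size=2, extra_dp_size=2, tp_size=2:
print(rank_to_coords(5, 2, 2, 2))  # (1, 0, 1)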
@@ -304,8 +304,8 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
-        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
-        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1.
+        extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1.
         enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
             Currently all the optimization methods include fused normalization, flash attention and JIT.
             Defaults to False.
@@ -347,8 +347,8 @@ class GeminiPlugin(DPPluginBase):
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
-        enable_tensor_parallelism: bool = False,
         tp_size: int = 1,
+        extra_dp_size:int = 1,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
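Putting the new docstring and signature together, the plugin would be constructed roughly as follows. This is a minimal sketch, not taken from the commit: the 8-GPU setup, precision choice, and toy model are illustrative, and it assumes the distributed environment (e.g. via colossalai.launch_from_torch) has already been initialized.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Illustrative: on 8 GPUs, tp_size=2 and extra_dp_size=2 leave a ZeRO group of
# size 8 // (2 * 2) = 2, which satisfies the divisibility assertion added below.
plugin = GeminiPlugin(precision="bf16", tp_size=2, extra_dp_size=2)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)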
@@ -393,7 +393,7 @@ class GeminiPlugin(DPPluginBase):
             max_norm=max_norm,
             norm_type=norm_type,
         )
-        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_tensor_parallelism = tp_size > 1
         self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
         self.enable_flash_attention = enable_flash_attention
@@ -402,12 +402,17 @@ class GeminiPlugin(DPPluginBase):
         self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose
 
-        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
-        self.dp_size = dist.get_world_size() // self.tp_size
-        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
-        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.tp_size = tp_size
+        self.extra_dp_size = extra_dp_size
+        world_size = dist.get_world_size()
+        self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
+        assert world_size == (self.tp_size * self.extra_dp_size) * self.zero_size, f"The global group size can't be evenly divided by the subgroup size."
+
+        self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
+        self.zero_group = self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
+        self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             enable_tensor_parallelism=self.enable_tensor_parallelism,
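To make the new mesh concrete, the following hypothetical helper (not part of the commit) enumerates which ranks would share each group for world_size = 8 with zero_size = extra_dp_size = tp_size = 2, again assuming the row-major rank layout described earlier:

from itertools import product

def enumerate_groups(zero_size: int, extra_dp_size: int, tp_size: int):
    # Flat rank for mesh coordinates (z, e, t) under an assumed row-major layout.
    def flat(z, e, t):
        return (z * extra_dp_size + e) * tp_size + t
    zero_groups = [[flat(z, e, t) for z in range(zero_size)]
                   for e, t in product(range(extra_dp_size), range(tp_size))]
    extra_dp_groups = [[flat(z, e, t) for e in range(extra_dp_size)]
                       for z, t in product(range(zero_size), range(tp_size))]
    tp_groups = [[flat(z, e, t) for t in range(tp_size)]
                 for z, e in product(range(zero_size), range(extra_dp_size))]
    return zero_groups, extra_dp_groups, tp_groups

zero, extra_dp, tp = enumerate_groups(2, 2, 2)
print(zero)      # [[0, 4], [1, 5], [2, 6], [3, 7]]  ZeRO shard groups
print(extra_dp)  # [[0, 2], [1, 3], [4, 6], [5, 7]]  ddp-like replica groups
print(tp)        # [[0, 1], [2, 3], [4, 5], [6, 7]]  tensor-parallel groups

Note the fallbacks in the diff itself: when zero_size equals the world size, the code reuses _get_default_group() rather than carving out a subgroup, and extra_dp_group / tp_group stay None when their sizes are 1.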
@@ -458,7 +463,7 @@ class GeminiPlugin(DPPluginBase):
                 shardformer = ShardFormer(self.shard_config)
                 model, _ = shardformer.optimize(model)
 
-            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
+            model = GeminiDDP(model, **self.gemini_config, zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose)
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
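For completeness, a rough training-step sketch (not from this commit) showing how the wrapped GeminiDDP model and its optimizer would then be driven through the Booster API; booster, model, and optimizer are assumed to come from the earlier booster.boost(...) example, and the data shapes are placeholders.

import torch

for step in range(10):
    inputs = torch.randn(16, 1024, device="cuda")
    outputs = model(inputs)            # forward through the GeminiDDP-wrapped model
    loss = outputs.mean()              # dummy loss for illustration
    booster.backward(loss, optimizer)  # routes backward through the wrapped optimizer
    optimizer.step()
    optimizer.zero_grad()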