Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[Feature] MoE Ulysses Support (#5918)
* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
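A minimal usage sketch for the feature follows. It assumes MoeHybridParallelPlugin accepts the same sequence-parallel arguments as HybridParallelPlugin (sp_size, enable_sequence_parallelism, sequence_parallelism_mode), which the diff's use of self.sp_size and self.sequence_parallelism_mode suggests; the launch, model, optimizer, and dataloader details are placeholders and not part of this commit.

# Hedged sketch: enable Ulysses-style (all_to_all) sequence parallelism together
# with expert parallelism via MoeHybridParallelPlugin. Argument names mirror
# HybridParallelPlugin and are assumptions, not a verified API surface.
# Run under torchrun so that the world size factors into pp * ep * moe_tp * sp * moe_dp.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,                              # expert-parallel group size
    sp_size=2,                              # Ulysses sequence-parallel group size
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)

# model / optimizer / dataloader stand in for a user-supplied MoE training setup:
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)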
@@ -1,4 +1,6 @@
 import warnings
 from collections import defaultdict
 from copy import deepcopy
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
 
@@ -22,6 +24,8 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
 
@@ -114,21 +118,25 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.ddp_config["find_unused_parameters"] = True
 
         world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size * self.sp_size)
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
             raise ValueError(
                 f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
             )
 
-        self._init_moe_param_comm()
+        # self._init_moe_param_comm()
 
         self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
 
         # set ep_group after super init
+        # TODO do it in a better way
+        self.moe_dp_group = self.pp_group
+        self.ep_group = self.pp_group
+        self.moe_tp_group = self.pp_group
 
         self.shard_config.ep_group = self.ep_group
         self.shard_config.moe_dp_group = self.moe_dp_group
         self.shard_config.moe_tp_group = self.moe_tp_group
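The sizing change above folds the sequence-parallel degree into the MoE device-mesh factorization: the world size must now split exactly into pp * moe_dp * ep * moe_tp * sp. A small arithmetic sketch with illustrative values (not taken from the commit):

# Hedged sketch of the sizing check, with made-up group sizes.
world_size = 16
pp_size, ep_size, moe_tp_size, sp_size = 2, 2, 1, 2

# New formula: sequence-parallel ranks are carved out before the MoE DP groups.
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size * sp_size)   # 16 // 8 = 2

# The guard added in the diff rejects configurations that do not tile the world exactly.
assert pp_size * moe_dp_size * ep_size * moe_tp_size * sp_size == world_size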
@@ -205,15 +213,32 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
+
+        # TODO: Support Galore + ZeRO
+        self.zero_stage
+        deepcopy(self.zero_config)
+        # Replace with distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if not isinstance(model, ModelWrapper):
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            else:
+                dp_group = self.dp_group
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=self.dp_group,
+                dp_group=dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
-                use_ddp=self.use_ddp,
+                use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
            )
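With all_to_all (Ulysses) sequence parallelism, gradients must be reduced across both the data-parallel and sequence-parallel ranks, hence the merged dp+sp group and the widened use_ddp condition above. A self-contained sketch of that DDP-enablement rule; the helper name and arguments are illustrative, not from the codebase:

# Hedged sketch: the DDP-enablement rule introduced above, as a pure function.
def should_use_ddp(dp_size: int, pp_size: int, zero_stage: int,
                   enable_sp: bool, sp_mode: str) -> bool:
    # Plain DDP path: real data parallelism without pipeline parallelism or ZeRO.
    plain_dp = dp_size > 1 and pp_size == 1 and zero_stage == 0
    # Ulysses path: even with dp_size == 1, all_to_all SP still needs DDP so that
    # gradients are averaged across the sequence-parallel ranks.
    ulysses_dp = dp_size == 1 and pp_size == 1 and enable_sp and sp_mode == "all_to_all"
    return plain_dp or ulysses_dp

assert should_use_ddp(1, 1, 0, True, "all_to_all") is True
assert should_use_ddp(1, 1, 0, False, "split_gather") is False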
@@ -224,6 +249,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 reinitialize_optimizer(optimizer, model)
 
             if self.zero_stage == 0:
+                is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
@@ -236,7 +262,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     )
                 else:
                     optimizer = HybridParallelNaiveOptimizer(
-                        optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+                        optimizer,
+                        model,
+                        use_pipeline=self.enable_pipeline_parallelism,
+                        param_info=param_info,
+                        max_norm=self.max_norm,
+                        pp_process_group=self.pp_group,
+                        tp_process_group=self.tp_group,
                     )
             else:
                 if not (self.dp_size > 1 or self.moe_dp_size > 1):
@@ -244,6 +276,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."
                     )
+                assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                 optimizer = MoeHybridParallelZeroOptimizer(
                     optimizer,
                     model,
@@ -262,4 +295,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
 
+        # Setup optimizers that require global states
+        optim = optimizer.optim
+        if isinstance(optim, DistributedOptim):
+            shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
+            padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
+            optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
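The new tail of configure() hands global state to distributed optimizers produced by cast_to_distributed. A hedged sketch of the non-ZeRO fallbacks used there: with is_zero False there is no master-to-working mapping and no gradient padding, so the maps degenerate to an empty dict and a defaultdict(int) that reports zero padding for every parameter.

# Hedged sketch: the fallback state handed to a DistributedOptim when ZeRO is disabled.
from collections import defaultdict

is_zero = False            # set by the zero_stage == 0 branch added earlier in the diff
optimizer = None           # placeholder; a real run would hold the wrapped optimizer here

# Without ZeRO there is no master/working parameter split and no padding to account for,
# so the maps degenerate to "empty" and "always zero".
shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)

assert shard_to_param == {}
assert padding_map["any_param_id"] == 0   # defaultdict(int) reports zero padding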