[moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [meo] support save sharded mixtral

* [meo] support load sharded mixtral

* [meo] support save sharded optim

* [meo] integrate moe manager into plug

* [meo] fix optimizer load

* [meo] fix mixtral layer
This commit is contained in:
Hongxin Liu
2024-01-25 15:48:46 +08:00
committed by ver217
parent c904d2ae99
commit da39d21b71
14 changed files with 996 additions and 550 deletions

View File

@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpintIO
from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self,
tp_size: int,
pp_size: int,
ep_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
@@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=self.real_dp_size,
fixed_ep_size=ep_size,
fixed_pp_size=pp_size,
use_ep_inside=use_ep_inside,
)
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
self.ep_size = ep_size
self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload

View File

@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
from .utils import has_index_file
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
else:
self.load_unsharded_model(model, checkpoint, strict)
return origin_model

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
if ctx.ep_size != 1:
grad = grad / ctx.ep_size
return grad, None
def _all_to_all(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
outputs_shape = list(inputs.shape)
if output_split_sizes is not None:
outputs_shape[0] = sum(output_split_sizes)
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle
class AllToAllUneven(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inputs,
input_split_sizes=None,
output_split_sizes=None,
group=None,
overlap: bool = False,
):
"""
Returns:
outputs: Tensor
handle: Optional[Work], if overlap is True
"""
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
@staticmethod
def backward(ctx: Any, *grad_outputs):
return (
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
None,
None,
None,
None,
)
def all_to_all_uneven(
inputs: torch.Tensor,
input_split_sizes: Optional[List[int]] = None,
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
):
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)

View File

@@ -26,3 +26,5 @@ class MoeParallelInfo:
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
self.ep_rank = self.pg.coordinate(self.ep_axis)
self.dp_rank = self.pg.coordinate(self.dp_axis)

View File

@@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def sync_moe_master_param(self):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.master_to_working_param
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}