Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 22:52:25 +00:00)
[moe] support mixtral (#5309)
* [moe] add mixtral block for single expert
* [moe] mixtral block fwd support uneven ep
* [moe] mixtral block bwd support uneven ep
* [moe] add mixtral moe layer
* [moe] simplify replace
* [moe] support save sharded mixtral
* [moe] support load sharded mixtral
* [moe] support save sharded optim
* [moe] integrate moe manager into plugin
* [moe] fix optimizer load
* [moe] fix mixtral layer
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self,
         tp_size: int,
         pp_size: int,
         ep_size: int,
+        extra_dp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
@@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

         if enable_sequence_parallelism:
             assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"

         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
+        MOE_MANAGER.setup(
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=self.real_dp_size,
+            fixed_ep_size=ep_size,
+            fixed_pp_size=pp_size,
+            use_ep_inside=use_ep_inside,
+        )
         self.tp_size = tp_size
         self.pp_size = pp_size
         self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+        self.ep_size = ep_size
+        self.moe_info = MOE_MANAGER.get_info(0)[1]
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
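For context, a minimal sketch of how the extended constructor might be wired into a booster. Only the constructor arguments (tp_size, pp_size, ep_size, extra_dp_size, precision, zero_stage) come from the hunk above; the import path, launch call, and model/optimizer names are assumptions.

# Hedged usage sketch; run under torchrun with a world size divisible by
# tp_size * pp_size * ep_size, as the assertions above require.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin  # import path assumed

colossalai.launch_from_torch(config={})  # launcher signature may differ by version

plugin = MoeHybridParallelPlugin(
    tp_size=1,            # no tensor parallelism in this sketch
    pp_size=1,            # no pipeline parallelism
    ep_size=2,            # experts sharded across 2 ranks
    extra_dp_size=1,      # argument introduced in this commit
    precision="fp16",
    zero_stage=1,
)
booster = Booster(plugin=plugin)
# model, optimizer = ...  # e.g. a Mixtral-style MoE model and its optimizer
# model, optimizer, *_ = booster.boost(model, optimizer)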
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.interface import ModelWrapper

-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

 __all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, checkpoint, strict)
+            path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+            if path.is_file():
+                self.load_unsharded_model(model, str(path), strict)
+            else:
+                path = Path(checkpoint, WEIGHTS_NAME)
+                if path.is_file():
+                    self.load_unsharded_model(model, str(path), strict)
+                else:
+                    self.load_unsharded_model(model, checkpoint, strict)

         return origin_model
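In words: when no index file is found, the loader now prefers a safetensors file inside the checkpoint directory, then a plain PyTorch weights file, and only then falls back to treating the checkpoint argument itself as the weight file. A small standalone sketch of that resolution order; the file names are the usual Hugging Face-style defaults and are assumed here, not read from the diff.

from pathlib import Path

# Assumed defaults; in ColossalAI they are imported from .utils as
# SAFE_WEIGHTS_NAME / WEIGHTS_NAME.
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"

def resolve_unsharded_checkpoint(checkpoint: str) -> str:
    """Pick the file an unsharded load would target, mirroring the new fallback order."""
    for name in (SAFE_WEIGHTS_NAME, WEIGHTS_NAME):
        candidate = Path(checkpoint, name)
        if candidate.is_file():
            return str(candidate)
    # Fall back to the original argument (it may already be a file path).
    return checkpoint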
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
         if ctx.ep_size != 1:
             grad = grad / ctx.ep_size
         return grad, None
+
+
+def _all_to_all(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    async_op: bool = False,
+):
+    """
+    Returns:
+        outputs: Tensor
+        handle: Optional[Work], if overlap is True
+    """
+    outputs_shape = list(inputs.shape)
+    if output_split_sizes is not None:
+        outputs_shape[0] = sum(output_split_sizes)
+    outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
+    inputs = inputs.contiguous()
+    outputs = outputs.contiguous()
+    handle = dist.all_to_all_single(
+        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+    )
+    return outputs, handle
+
+
+class AllToAllUneven(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        inputs,
+        input_split_sizes=None,
+        output_split_sizes=None,
+        group=None,
+        overlap: bool = False,
+    ):
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
+        ctx.input_split_sizes = input_split_sizes
+        ctx.output_split_sizes = output_split_sizes
+        ctx.group = group
+        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs):
+        return (
+            _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def all_to_all_uneven(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    overlap: bool = False,
+):
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
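The split-size pairing follows the usual all-to-all contract: entry i of input_split_sizes is how many rows this rank sends to rank i, and entry i of output_split_sizes is how many rows it expects back from rank i (which must equal rank i's send count toward this rank). A hedged usage sketch follows; the dispatch wrapper, tensor shapes, and counts are illustrative assumptions, not code from this commit.

# Hedged sketch of calling all_to_all_uneven inside an expert-parallel
# dispatch. Assumes torch.distributed is already initialized (e.g. via
# torchrun with 2 ranks).
import torch
import torch.distributed as dist

from colossalai.moe._operation import all_to_all_uneven  # import path assumed

def dispatch_tokens(hidden_states: torch.Tensor, send_counts: list, recv_counts: list):
    # send_counts[i]: number of rows this rank sends to rank i.
    # recv_counts[i]: number of rows this rank receives from rank i.
    outputs, handle = all_to_all_uneven(
        hidden_states,
        input_split_sizes=send_counts,
        output_split_sizes=recv_counts,
        group=dist.group.WORLD,
        overlap=False,  # synchronous: handle is None and can be ignored
    )
    return outputs

# e.g. on rank 0 of a 2-rank job:
#   x = torch.randn(3, hidden_dim, device="cuda")
#   out = dispatch_tokens(x, send_counts=[2, 1], recv_counts=[2, 2])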
@@ -26,3 +26,5 @@ class MoeParallelInfo:
         self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
         self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
         self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
+        self.ep_rank = self.pg.coordinate(self.ep_axis)
+        self.dp_rank = self.pg.coordinate(self.dp_axis)
@@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
         self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

-    def sync_moe_master_param(self):
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
-
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.

@@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                    master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
                else:
                    master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.copy_(working_moe_param)

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return self._param_store.master_to_working_param
+        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
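The pattern behind these two hunks is standard mixed-precision bookkeeping: the optimizer steps on fp32 master copies while the model holds low-precision working parameters, and MoE (expert) parameters are now included in the master/working mapping. A minimal standalone sketch of that round trip with plain tensors; the names and the id()-keyed map are illustrative assumptions, not the optimizer's internals.

import torch

# Hypothetical working (fp16) and master (fp32) copies of one expert weight.
working_moe_param = torch.nn.Parameter(torch.randn(8, 8, dtype=torch.float16))
master_moe_param = working_moe_param.detach().clone().to(torch.float32)

# Master-id -> working-tensor map, analogous to the merged dict returned by
# get_master_to_working_map() above (its Dict[int, Tensor] keys suggest id()).
master_to_working = {id(master_moe_param): working_moe_param}

with torch.no_grad():
    # Loading a checkpoint refreshes the master copy from the working copy,
    # mirroring master_moe_param.copy_(working_moe_param) in the second hunk.
    master_moe_param.copy_(working_moe_param)

    # After an optimizer step on the fp32 master, the working fp16 copy is
    # refreshed from it (copy_ handles the dtype cast).
    master_moe_param.add_(0.01)
    master_to_working[id(master_moe_param)].copy_(master_moe_param)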