[moe] add mixtral dp grad scaling when not all experts are activated

botbw 2024-07-12 03:27:20 +00:00 committed by hxwang
parent 2f9bce6686
commit 1b15cc97f5
8 changed files with 98 additions and 42 deletions

View File

@@ -141,6 +141,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         # set ep_group after super init
         # TODO do it in a better way
         self.shard_config.ep_group = self.ep_group
+        self.shard_config.moe_dp_group = self.moe_dp_group
+        self.shard_config.moe_tp_group = self.moe_tp_group

         self.force_overlap_comm = force_overlap_comm
@@ -159,7 +161,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         # create groups from submesh
         for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
-            # axis 0 is dp, axis 1 is tp, axis 2 is sp
+            # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
             submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
             # hardcode here since we only have 3 axis
@@ -188,7 +190,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     assert self.moe_tp_group is None
                     self.moe_tp_group = group
-        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])

     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
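
The relabelled axis comment above maps each pipeline stage's ranks onto a 3-D submesh. As a quick illustration, here is a standalone numpy sketch (hypothetical rank numbers and sizes, not part of the commit; the way groups are read off the axes is the usual reading of that comment, stated here as an assumption):

import numpy as np

# 8 ranks in one pipeline stage, assumed sizes moe_dp=2, ep=2, moe_tp=2
stage_rank = list(range(8))
moe_dp_size, ep_size, moe_tp_size = 2, 2, 2
submesh = np.array(stage_rank).reshape(moe_dp_size, ep_size, moe_tp_size)

# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp, so for example:
print(submesh[:, 0, 0])  # ranks forming one moe_dp group -> [0 4]
print(submesh[0, :, 0])  # ranks forming one ep group     -> [0 2]
print(submesh[0, 0, :])  # ranks forming one moe_tp group -> [0 1]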

View File

@@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
     return torch.cumsum(inputs, dim=0) - 1


-class MoeInGradScaler(torch.autograd.Function):
+class EPGradScalerIn(torch.autograd.Function):
     """
     Scale the gradient back by the number of experts
     because the batch size increases in the moe stage
@@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
-        if ctx is not None:
-            ctx.ep_size = ep_size
+        ctx.ep_size = ep_size
         return inputs

     @staticmethod
@@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function):
         return grad, None


-class MoeOutGradScaler(torch.autograd.Function):
+class EPGradScalerOut(torch.autograd.Function):
     """
     Scale the gradient by the number of experts
     because the batch size increases in the moe stage
@@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function):
         return grad, None


+class DPGradScalerIn(torch.autograd.Function):
+    """
+    Scale the gradient back by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.activated_experts / ctx.moe_dp_size)
+        return grad, None, None
+
+
+class DPGradScalerOut(torch.autograd.Function):
+    """
+    Scale the gradient by the number of experts
+    because the batch size increases in the moe stage
+    """
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
+        assert activated_experts != 0, f"shouldn't be called when no expert is activated"
+        ctx.moe_dp_size = moe_dp_size
+        ctx.activated_experts = activated_experts
+        return inputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]:
+        assert len(grad_outputs) == 1
+        grad = grad_outputs[0]
+        if ctx.moe_dp_size != ctx.activated_experts:
+            grad.mul_(ctx.moe_dp_size / ctx.activated_experts)
+        return grad, None, None
+
+
 def _all_to_all(
     inputs: torch.Tensor,
     input_split_sizes: Optional[List[int]] = None,
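
For intuition about the new pair of scalers, a minimal single-process sketch with toy shapes; it only assumes the two classes behave as defined in the hunk above. The scalers are identities in the forward pass; in the backward pass only the parameters between them see gradients rescaled by moe_dp_size / activated_experts, while gradients flowing further back are restored by the inverse factor:

import torch
from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut

moe_dp_size = 4  # replicas whose gradients get averaged for this expert
activated = 3    # replicas that actually routed tokens to it this step

expert = torch.nn.Linear(16, 16)
x = torch.randn(8, 16, requires_grad=True)

h = DPGradScalerIn.apply(x, moe_dp_size, activated)   # identity in forward
h = expert(h)
h = DPGradScalerOut.apply(h, moe_dp_size, activated)  # identity in forward

# use an explicit, writable upstream gradient, since the scalers modify it in place
h.backward(torch.ones_like(h))

# expert.weight.grad is now moe_dp_size / activated (= 4/3) times what it would be
# without the scalers, compensating for averaging over replicas that saw no tokens
# for this expert; x.grad is unchanged because the In scaler applies the inverse
# factor on the way back.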

View File

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
         Returns:
             torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
         """
-        x = MoeInGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerIn.apply(x, self.ep_size)
         e = x.size(1)
         h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
             x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
-        x = MoeOutGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerOut.apply(x, self.ep_size)
         return x

View File

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
@@ -118,7 +118,7 @@ class MLPExperts(nn.Module):
         Returns:
             torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
         """
-        x = MoeInGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerIn.apply(x, self.ep_size)
         e = x.size(1)
         h = x.size(-1)
@@ -157,5 +157,5 @@ class MLPExperts(nn.Module):
             x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
-        x = MoeOutGradScaler.apply(x, self.ep_size)
+        x = EPGradScalerOut.apply(x, self.ep_size)
         return x

View File

@@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging

 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


 class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
-    def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
-        super().__init__(config)
-        self.setup_process_groups(ep_group, tp_group, moe_tp_group)
+    def __init__(self, *args, **kwargs):
+        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+        assert tp_group is not None
+        assert moe_dp_group is not None
+        assert ep_group is not None
+        assert moe_tp_group is not None
+
         # setup ep group
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
@@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
-            p.ep_group = ep_group
+            set_moe_tensor_ep_group(p, ep_group)
+
+        # setup moe_dp group
+        self.moe_dp_group = moe_dp_group
+        self.moe_dp_size = moe_dp_group.size()

         # setup global tp group
         self.tp_group = tp_group
@@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs
+        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
     ) -> "EPMixtralSparseMoeBlock":
+        # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(ep_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

+        with torch.no_grad():
+            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+            for i in range(1, self.ep_size):
+                activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+            activate_experts = (activate_experts > 0).float()
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()

-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)

         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
             if self.num_experts_per_ep == 1:
                 # no need to split
                 expert = self.experts[self.expert_start_idx]
+                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
                 output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                 output_states = expert.w2(output_states)
+                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
             else:
                 output_states_splits = output_states.split(output_split_sizes.tolist())
                 output_states_list = []
                 for i, split_states in enumerate(output_states_splits):
                     if split_states.size(0) == 0:
                         continue
+                    split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                     split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                     split_states = expert.w2(split_states)
+                    split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+
+        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
-        if self.tp_group is not None and self.tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)

         recover_experts_idx = torch.empty_like(selected_experts_idx)
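
To make the new bookkeeping concrete, the snippet below replays the activate_experts computation in a single process with made-up token counts; the all_reduce over moe_dp_group is only described in a comment, since it needs an initialized process group:

import torch

ep_size, num_experts_per_ep = 2, 2
# tokens each EP peer will send to our local experts: rank0 -> (e0, e1), rank1 -> (e0, e1)
output_split_sizes = torch.tensor([3, 0, 5, 0])

activate_experts = output_split_sizes[:num_experts_per_ep].clone()
for i in range(1, ep_size):
    activate_experts += output_split_sizes[i * num_experts_per_ep : (i + 1) * num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
print(activate_experts)  # tensor([1., 0.]) -> local expert 0 is used this step, expert 1 is not

# In the real forward, dist.all_reduce(activate_experts, group=self.moe_dp_group) then
# sums these 0/1 flags across moe_dp ranks, so each entry becomes the number of
# data-parallel replicas that activated that expert; that count is what
# DPGradScalerIn/Out receive as `activated_experts`.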

View File

@@ -76,18 +76,6 @@ class MixtralPolicy(Policy):
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
                 ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.gate_proj",
-                #     target_module=Linear1D_Col,
-                # ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.up_proj",
-                #     target_module=Linear1D_Col,
-                # ),
-                # SubModuleReplacementDescription(
-                #     suffix="mlp.down_proj",
-                #     target_module=Linear1D_Row,
-                # ),
+                # TODO: enable moe tp parallel
             ],
         )
@@ -98,7 +86,7 @@ class MixtralPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="block_sparse_moe",
                     target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group},
+                    kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group},
                 )
             ],
             policy=policy,

View File

@@ -46,6 +46,9 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+    # for moe related
+    moe_dp_group: Optional[ProcessGroup] = None
     ep_group: Optional[ProcessGroup] = None
     moe_tp_group: Optional[ProcessGroup] = None

View File

@@ -18,8 +18,7 @@ NUM_BATCH=4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS=2
-TOP_K = 2
+TOP_K = 1

 def split_grad(grad, world_size):
     with torch.no_grad():
@@ -96,7 +95,6 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     # check grad
     name_to_p = {n: p for n, p in ddp_model.named_parameters()}
     for n, p in zero_model.named_parameters():
-        print(f"rank {dist.get_rank()} {n}")
         zero_grad = zero_optimizer.get_param_grad(p)
         if name_to_p[n].grad is None:
             name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
@@ -124,9 +122,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(world_size):
+def test_moe_ep_zero(world_size):
     spawn(run_dist, world_size)


 if __name__ == "__main__":
-    test_moe_ep_tp(world_size=4)
+    test_moe_ep_zero(world_size=4)
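
As a rough intuition for the gradient check in this test (toy numbers, not the test's actual tensors): when only some moe_dp replicas route tokens to an expert, a plain average of their gradients under-weights that expert, and the moe_dp_size / activated_experts factor applied by the new scalers recovers the average over the replicas that actually used it:

import torch

moe_dp_size, activated = 4, 1
# per-replica gradients for one expert weight; replicas without tokens contribute zeros
per_rank_grads = [torch.full((2, 2), 3.0)] + [torch.zeros(2, 2)] * (moe_dp_size - 1)

averaged = torch.stack(per_rank_grads).mean(dim=0)   # what plain DP averaging yields
rescaled = averaged * (moe_dp_size / activated)      # effect of the DP grad scalers
print(averaged[0, 0].item(), rescaled[0, 0].item())  # 0.75 vs. 3.0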