diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
index 95a9884a9..f102ddc01 100644
--- a/colossalai/nn/layer/moe/__init__.py
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -1,5 +1,8 @@
-from .experts import Experts, FFNExperts
+from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator
+from .utils import NormalNoiseGenerator, build_ffn_experts
 
-__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
+__all__ = [
+    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
+    'build_ffn_experts'
+]
diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py
index 662a907fc..a28c1cda8 100644
--- a/colossalai/nn/layer/moe/_operation.py
+++ b/colossalai/nn/layer/moe/_operation.py
@@ -15,6 +15,55 @@ except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
 
 
+class AllGather(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode
+
+        comm_size = gpc.get_world_size(parallel_mode)
+        if comm_size == 1:
+            return inputs.unsqueeze(0)
+
+        buffer_shape = (comm_size,) + inputs.shape
+        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
+        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
+        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+
+
+class ReduceScatter(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode
+
+        comm_size = gpc.get_world_size(parallel_mode)
+        if comm_size == 1:
+            return inputs.squeeze(0)
+
+        if not inputs.is_contiguous():
+            inputs = inputs.contiguous()
+
+        output_shape = inputs.shape[1:]
+        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
+        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
+        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+
+
 class AllToAll(torch.autograd.Function):
     """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
     operation in torch.distributed.
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index ef9618c1a..797eb9c24 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -7,7 +7,16 @@ from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 
 
-class Experts(nn.Module):
+class MoeExperts(nn.Module):
+
+    def __init__(self, comm: str):
+        super().__init__()
+        assert comm in {"all_to_all", "all_gather"}, \
+            "This kind of communication has not been implemented yet.\n Please use Experts build function."
+        self.comm = comm
+
+
+class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
     is a instence of the class, 'expert' in initialization parameters.
@@ -20,19 +29,22 @@ class Experts(nn.Module):
     """
 
     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__()
+        super().__init__("all_to_all")
 
         assert num_experts % moe_env.model_parallel_size == 0, \
             "The number of experts should be divied by moe model size"
         num_local_experts = num_experts // moe_env.model_parallel_size
+
         with seed(ParallelMode.MOE_MODEL):
             self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
-        self.num_local_experts = num_local_experts
+
         for exp in self.experts:
             for param in exp.parameters():
                 param.__setattr__('moe_param', True)
 
+        self.num_local_experts = num_local_experts
+
     def forward(self, inputs):
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []
@@ -44,10 +56,10 @@ class Experts(nn.Module):
         return output
 
 
-class FFNExperts(nn.Module):
+class FFNExperts(MoeExperts):
 
     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__()
+        super().__init__("all_to_all")
 
         assert num_experts % moe_env.model_parallel_size == 0, \
             "The number of experts should be divied by moe model size"
@@ -75,7 +87,7 @@ class FFNExperts(nn.Module):
         for param in self.parameters():
             param.__setattr__('moe_param', True)
 
-    def forward(self, inputs):  # x [g, el, c, h]
+    def forward(self, inputs):  # inputs [g, el, c, h]
 
         el = inputs.size(1)
         h = inputs.size(-1)
@@ -96,3 +108,58 @@ class FFNExperts(nn.Module):
         outputs = outputs.reshape(inshape)
         outputs = outputs.transpose(0, 1).contiguous()
         return outputs
+
+
+class TPExperts(MoeExperts):
+
+    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+        super().__init__("all_gather")
+
+        assert d_ff % moe_env.model_parallel_size == 0, \
+            "d_ff should be divied by moe model size"
+
+        p_ff = d_ff // moe_env.model_parallel_size
+
+        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
+
+        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
+
+        s1 = math.sqrt(0.1 / d_model)
+        s2 = math.sqrt(0.1 / d_ff)
+
+        with seed(ParallelMode.MOE_MODEL):
+            nn.init.trunc_normal_(self.w1, std=s1)
+            nn.init.trunc_normal_(self.b1, std=s1)
+            nn.init.trunc_normal_(self.w2, std=s2)
+
+        nn.init.trunc_normal_(self.b2, std=s2)
+
+        self.act = nn.GELU() if activation is None else activation
+        self.drop = nn.Dropout(p=drop_rate)
+
+        self.w1.__setattr__('moe_param', True)
+        self.w2.__setattr__('moe_param', True)
+        self.b1.__setattr__('moe_param', True)
+
+    def forward(self, inputs):  # inputs [g, e, c, h]
+
+        e = inputs.size(1)
+        h = inputs.size(-1)
+
+        inputs = inputs.transpose(0, 1)
+        inshape = inputs.shape
+        inputs = inputs.reshape(e, -1, h)
+
+        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
+        out_act = self.act(out_ff)
+        with seed(ParallelMode.TENSOR):
+            inter = self.drop(out_act)
+
+        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        outputs = self.drop(out_model)  # outputs [e, gc, h]
+
+        outputs = outputs.reshape(inshape)
+        outputs = outputs.transpose(0, 1).contiguous()
+        return outputs  # outputs [g, e, c, h]
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index 0abe7ac8c..71d54c298 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -8,7 +8,8 @@ from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from .experts import MoeExperts
 from .utils import autocast_softmax
 
 
@@ -198,7 +199,7 @@ class MoeLayer(nn.Module):
     :type experts: nn.Module
     """
 
-    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
+    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
@@ -207,8 +208,8 @@ class MoeLayer(nn.Module):
         self.experts = experts
         self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
 
-    def expert_part(self, expert_input: torch.Tensor):
-        expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
+    def a2a_process(self, dispatch_data: torch.Tensor):
+        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
 
         input_shape = expert_input.shape
 
@@ -221,24 +222,42 @@ class MoeLayer(nn.Module):
         expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
         return expert_output
 
+    def tp_process(self, dispatch_data: torch.Tensor):
+        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_out = self.experts(expert_in)
+        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        return expert_out
+
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
         router_res = self.router(gate_output, self.cuda_mode)
 
         if self.cuda_mode:
-            logits, mask, dest_idx, ec = router_res
-            expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
-            expert_output = self.expert_part(expert_input)
-            ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
+            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            combine_weights, sec_mask = router_res
-            sec_mask_f = sec_mask.type_as(inputs)
-            expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
-            expert_output = self.expert_part(expert_input)
+            sec_mask_f = router_res[1].type_as(inputs)
+            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+
+        # dispatch_data [e, c, h]
+        if self.experts.comm == "all_to_all":
+            expert_output = self.a2a_process(dispatch_data)
+        elif self.experts.comm == "all_gather":
+            expert_output = self.tp_process(dispatch_data)
+        else:
+            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
+                                      "build function.")
+        # expert_output [e, c, h]
+
+        if self.cuda_mode:
+            expert_output = expert_output.reshape(-1, self.d_model)
+            ans = MoeCombine.apply(expert_output, *router_res)
+        else:
+            combine_weights = router_res[0]
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
-            ret = torch.matmul(combine_weights, expert_output)
+            ans = torch.matmul(combine_weights, expert_output)
 
-        ret = ret.reshape(inputs.shape)
-        return ret
+        ans = ans.reshape(inputs.shape)
+        return ans
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index 060741c4f..4fa090662 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
+from colossalai.global_variables import moe_env
+from .experts import FFNExperts, TPExperts
 
 
 class NormalNoiseGenerator:
@@ -14,10 +16,9 @@ class NormalNoiseGenerator:
     """
 
     def __init__(self, num_experts: int):
-        self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
-        ).rsample
+        self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
+                                                        scale=torch.tensor(1.0 / num_experts**2,
+                                                                           device=get_current_device())).rsample
 
     def __call__(self, inputs: torch.Tensor):
         noisy = self.normal(inputs.shape)
@@ -30,3 +31,13 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
     sm_input = inputs.to(torch.float32) if fp16_flag else inputs
     sm_output = F.softmax(sm_input, dim)
     return sm_output
+
+
+def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+    moe_mp_size = moe_env.model_parallel_size
+    if num_experts % moe_mp_size == 0:
+        return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    elif d_ff % moe_mp_size == 0:
+        return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    else:
+        raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
diff --git a/model_zoo/moe/models.py b/model_zoo/moe/models.py
index cffd837a4..277e377c7 100644
--- a/model_zoo/moe/models.py
+++ b/model_zoo/moe/models.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.global_variables import moe_env
@@ -110,7 +110,7 @@ class Widenet(nn.Module):
 
         noisy_func = NormalNoiseGenerator(num_experts)
         shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
-        shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+        shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
 
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
@@ -177,7 +177,7 @@ class ViTMoE(nn.Module):
             ffn = VanillaFFN(**moe_mlp_args(
                 d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
                 MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),
diff --git a/tests/test_moe/short_test.py b/tests/test_moe/short_test.py
index 3b919345d..af77f4d6f 100644
--- a/tests/test_moe/short_test.py
+++ b/tests/test_moe/short_test.py
@@ -1,6 +1,4 @@
-import os
 from functools import partial
-from pathlib import Path
 import pytest
 import torch
 import torch.nn as nn
@@ -9,10 +7,10 @@ import colossalai
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top2Router, MoeLayer
+from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
+from colossalai.context.random import moe_set_seed
 from colossalai.global_variables import moe_env
 
-
 BATCH_SIZE = 32
 NUM_EXPERTS = 4
 CONFIG = dict(parallel=dict(moe=dict(size=4)))
@@ -24,17 +22,17 @@ def check_equal(A, B, atol=1e-06):
 
 def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+    moe_set_seed(42)
     # torch.set_printoptions(precision=30)
     torch.backends.cuda.matmul.allow_tf32 = False
     local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
     torch.manual_seed(rs + local_rank)
     moe_env.reset_loss()
-    tokens = torch.randn(BATCH_SIZE, hidden_size,
-                         dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
     # print(f"tokens:\n{tokens}")
     router = Top2Router(1)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
+    expert = Experts(nn.Identity, 4)
+    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
     if data_type == torch.float16:
         layer = layer.half()
     layer.cuda_mode = False
@@ -88,8 +86,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 def test_moe_top2(rs, hidden_size, data_type):
     world_size = 4
-    run_func = partial(run_routing, world_size=world_size, port=free_port(),
-                       rs=rs, hidden_size=hidden_size, data_type=data_type)
+    run_func = partial(run_routing,
+                       world_size=world_size,
+                       port=free_port(),
+                       rs=rs,
+                       hidden_size=hidden_size,
+                       data_type=data_type)
     mp.spawn(run_func, nprocs=world_size)
 
 
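
For reviewers, a minimal usage sketch of how the pieces added in this PR compose. It assumes a ColossalAI context has already been launched with an MoE parallel group (e.g. CONFIG = dict(parallel=dict(moe=dict(size=4))), as in tests/test_moe/short_test.py); the hyper-parameters and tensor shapes below are illustrative, not part of the diff.

    import torch
    from colossalai.global_variables import moe_env
    from colossalai.nn.layer.moe import MoeLayer, NormalNoiseGenerator, Top2Router, build_ffn_experts
    from colossalai.utils import get_current_device

    # assumes colossalai.launch(...) has already been called with an MoE parallel config
    d_model, d_ff, num_experts = 128, 512, 4

    # build_ffn_experts picks FFNExperts (experts sharded over the MoE model parallel
    # group, dispatched with all_to_all) when num_experts is divisible by the MoE model
    # parallel size, falls back to TPExperts (d_ff sharded, all_gather/reduce_scatter)
    # when only d_ff is divisible, and raises NotImplementedError otherwise.
    experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=0.1)

    router = Top2Router(1, noisy_func=NormalNoiseGenerator(num_experts))
    moe = MoeLayer(dim_model=d_model, num_experts=num_experts, router=router, experts=experts)

    moe_env.reset_loss()
    tokens = torch.randn(32, d_model, device=get_current_device())
    out = moe(tokens)    # MoeLayer selects a2a_process or tp_process from experts.comm

With this change MoeLayer no longer hard-codes the all_to_all path: it reads experts.comm, set by the new MoeExperts base class, to choose between a2a_process and tp_process, so any expert module produced by build_ffn_experts works unchanged.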