diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 2a51344c3..05333fe96 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,9 +1,10 @@ -from .experts import Experts, FFNExperts, TPExperts -from .layers import MoeLayer, MoeModule -from .routers import MoeRouter, Top1Router, Top2Router -from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts - -__all__ = [ - 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter' -] +from .checkpoint import load_moe_model, save_moe_model +from .experts import Experts, FFNExperts, TPExperts +from .layers import MoeLayer, MoeModule +from .routers import MoeRouter, Top1Router, Top2Router +from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts + +__all__ = [ + 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', + 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model' +] diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py new file mode 100644 index 000000000..efda1f222 --- /dev/null +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -0,0 +1,40 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +from .experts import MoeExperts + + +def save_moe_model(model: nn.Module, save_path: str): + state_dict = model.state_dict() + if dist.get_rank() == 0: + torch.save(state_dict, save_path) + dist.barrier() + + +def load_moe_model(model: nn.Module, load_path: str): + state_dict = torch.load(load_path) + + for prefix, module in model.named_modules(): + if prefix.endswith('.moe_layer.experts'): + # this module should be an Experts instance + assert isinstance(module, MoeExperts) + + ep_rank = dist.get_rank(module.dist_info.ep_group) + num_local = module.num_local_experts + for i in range(num_local): + expert_id = ep_rank * num_local + i + for name, _ in module.experts[i].named_parameters(): + cur_key = f'{prefix}.experts.{i}.{name}' + param_key = f'{prefix}.experts.{expert_id}.{name}' + load_param = state_dict[param_key] + state_dict[cur_key] = load_param + + for name, _ in module.experts[0].named_parameters(): + pop_pre = f'{prefix}.experts.' + pop_suf = f'.{name}' + for i in range(num_local, module.num_total_experts): + pop_key = f'{pop_pre}{i}{pop_suf}' + state_dict.pop(pop_key) + + model.load_state_dict(state_dict) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 055afded9..4fb9ad332 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -1,172 +1,203 @@ -import math - -import torch -import torch.nn as nn -from colossalai.context import ParallelMode, seed -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.zero.init_ctx import no_shard_zero_decrator -from typing import Type - - -class MoeExperts(nn.Module): - """Basic class for experts in MoE. It stores what kind of communication expersts use - to exchange tokens, how many experts in a single GPU and parallel information such as - expert parallel size, data parallel size and their distributed communication groups. 
- """ - - def __init__(self, comm_name: str, num_experts: int): - super().__init__() - assert comm_name in {"all_to_all", "all_gather"}, \ - "This kind of communication has not been implemented yet.\n Please use Experts build function." - self.comm_name = comm_name - # Get the configuration of experts' deployment and parallel information from moe contex - self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) - - -@no_shard_zero_decrator(is_replicated=False) -class Experts(MoeExperts): - """A wrapper class to create experts. It will create E experts across the - moe model parallel group, where E is the number of experts. Every expert - is a instence of the class, 'expert' in initialization parameters. - - Args: - expert_cls (:class:`torch.nn.Module`): The class of all experts - num_experts (int): The number of experts - expert_args: Args used to initialize experts, the args could be found in corresponding expert class - """ - - def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): - super().__init__("all_to_all", num_experts) - - # Use seed to make every expert different from others - with seed(ParallelMode.TENSOR): - self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) - - # Attach parallel information for all parameters in Experts - for exp in self.experts: - for param in exp.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs: torch.Tensor): - # Split inputs for each expert - expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) - expert_output = [] - - # Get outputs from each expert - for i in range(self.num_local_experts): - expert_output.append(self.experts[i](expert_input[i])) - - # Concatenate all outputs together - output = torch.cat(expert_output, dim=1).contiguous() - return output - - -class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts. 
- """ - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_to_all", num_experts) - - self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - for param in self.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] - - el = inputs.size(1) - h = inputs.size(-1) - - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(el, -1, h) - - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] - - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs - - -class TPExperts(MoeExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divied by maximum expert parallel size or - maximum expert parallel size can't be divied by the number of experts. 
- """ - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - - assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ - "d_ff should be divied by maximum expert parallel size" - - p_ff = d_ff // MOE_CONTEXT.max_ep_size - - self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - self.w1.__setattr__('moe_info', self.dist_info) - self.w2.__setattr__('moe_info', self.dist_info) - self.b1.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, e, c, h] - - e = inputs.size(1) - h = inputs.size(-1) - - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(e, -1, h) - - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] - - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] +import math +from copy import deepcopy +from typing import Type + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.context import ParallelMode, seed +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device +from colossalai.zero.init_ctx import no_shard_zero_decrator + + +class MoeExperts(nn.Module): + """Basic class for experts in MoE. It stores what kind of communication expersts use + to exchange tokens, how many experts in a single GPU and parallel information such as + expert parallel size, data parallel size and their distributed communication groups. + """ + + def __init__(self, comm_name: str, num_experts: int): + super().__init__() + assert comm_name in {"all_to_all", "all_gather"}, \ + "This kind of communication has not been implemented yet.\n Please use Experts build function." + self.comm_name = comm_name + self.num_total_experts = num_experts + # Get the configuration of experts' deployment and parallel information from moe contex + self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) + + +@no_shard_zero_decrator(is_replicated=False) +class Experts(MoeExperts): + """A wrapper class to create experts. It will create E experts across the + moe model parallel group, where E is the number of experts. Every expert + is a instence of the class, 'expert' in initialization parameters. 
+ + Args: + expert_cls (:class:`torch.nn.Module`): The class of all experts + num_experts (int): The number of experts + expert_args: Args used to initialize experts, the args could be found in corresponding expert class + """ + + def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): + super().__init__("all_to_all", num_experts) + + # Use seed to make every expert different from others + with seed(ParallelMode.TENSOR): + self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) + + # Attach parallel information for all parameters in Experts + for exp in self.experts: + for param in exp.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs: torch.Tensor): + # Split inputs for each expert + expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) + expert_output = [] + + # Get outputs from each expert + for i in range(self.num_local_experts): + expert_output.append(self.experts[i](expert_input[i])) + + # Concatenate all outputs together + output = torch.cat(expert_output, dim=1).contiguous() + return output + + def state_dict(self, destination=None, prefix='', keep_vars=False): + assert keep_vars == False, "Only support keep_vars=False now" + dp_rank = dist.get_rank(self.dist_info.dp_group) + ep_rank = dist.get_rank(self.dist_info.ep_group) + submodule_dict = dict() + example_submodule = None + for name, subm in self.experts.named_modules(): + if subm is self.experts: + continue + module_number = self.num_local_experts * ep_rank + int(name) + submodule_dict[module_number] = subm + example_submodule = subm + + if dp_rank == 0: + local_prefix = prefix + 'experts.' + buffer_module = deepcopy(example_submodule) + for i in range(self.num_total_experts): + source_rank = i // self.num_local_experts + current_prefix = local_prefix + str(i) + '.' + comm_module = submodule_dict.get(i, buffer_module) + for name, param in comm_module.named_parameters(): + dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) + if ep_rank == 0: + destination[current_prefix + name] = param.data.cpu() + + dist.barrier() + + +class FFNExperts(MoeExperts): + """Use torch.bmm to speed up for multiple experts. 
+ """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_to_all", num_experts) + + self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + for param in self.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, el, c, h] + + el = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(el, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + with seed(ParallelMode.TENSOR): + outputs = self.drop(out_model) # outputs [el, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs + + +class TPExperts(MoeExperts): + """Use tensor parallelism to split each expert evenly, which can deploy experts in + case that the number of experts can't be divied by maximum expert parallel size or + maximum expert parallel size can't be divied by the number of experts. 
+ """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_gather", MOE_CONTEXT.max_ep_size) + + assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ + "d_ff should be divied by maximum expert parallel size" + + p_ff = d_ff // MOE_CONTEXT.max_ep_size + + self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + self.w1.__setattr__('moe_info', self.dist_info) + self.w2.__setattr__('moe_info', self.dist_info) + self.b1.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, e, c, h] + + e = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(e, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + outputs = self.drop(out_model) # outputs [e, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 259f53f1a..0969eb818 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -1,203 +1,210 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils import get_current_device -from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \ - ReduceScatter, MoeDispatch, MoeCombine -from colossalai.nn.layer.moe.experts import MoeExperts, Experts -from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator -from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router -from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator -from typing import Optional, Type, Tuple - - -@no_shard_zero_decrator(is_replicated=True) -class MoeLayer(nn.Module): - """A MoE layer, that puts its input tensor to its gate and uses the output logits - to router all tokens, is mainly used to exchange all tokens for every expert across - the moe tensor group by all to all comunication. Then it will get the output of all - experts and exchange the output. At last returns the output of the moe system. - - Args: - dim_model (int): Dimension of model. - num_experts (int): The number of experts. - router (MoeRouter): Instance of router used in routing. - experts (MoeExperts): Instance of experts generated by Expert. 
- """ - - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): - super().__init__() - self.d_model = dim_model - self.num_experts = num_experts - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) - self.router: MoeRouter = router - self.experts: MoeExperts = experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = experts.dist_info.ep_group - self.ep_size = experts.dist_info.ep_size - self.num_local_experts = experts.num_local_experts - - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - - def a2a_process(self, dispatch_data: torch.Tensor): - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output - - def tp_process(self, dispatch_data: torch.Tensor): - expert_in = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group) - return expert_out - - def forward(self, inputs: torch.Tensor) -> Tuple: - # reshape the input tokens - tokens = inputs.reshape(-1, self.d_model) - - # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) - - # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) - - if self.use_kernel: - dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) - dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) - else: - sec_mask_f = route_result_list[1].type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - # dispatch_data [e, c, h] - if self.experts.comm_name == "all_to_all": - expert_output = self.a2a_process(dispatch_data) - elif self.experts.comm_name == "all_gather": - expert_output = self.tp_process(dispatch_data) - else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") - # expert_output [e, c, h] - if self.use_kernel: - expert_output = expert_output.reshape(-1, self.d_model) - ans = MoeCombine.apply(expert_output, *route_result_list) - else: - combine_weights = route_result_list[0].type_as(inputs) - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - ans = torch.matmul(combine_weights, expert_output) - - ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux - - -class MoeModule(nn.Module): - """A class for users to create MoE modules in their models. - - Args: - dim_model (int): Hidden dimension of training model - num_experts (int): The number experts - top_k (int, optional): The number of experts for dispatchment of each token - capacity_factor_train (float, optional): Capacity factor in routing during training - capacity_factor_eval (float, optional): Capacity factor in routing during evaluation - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_policy (str, optional): The policy of noisy function. 
Now we have 'Jitter' and 'Gaussian'. - 'Jitter' can be found in `Switch Transformer paper`_. - 'Gaussian' can be found in `ViT-MoE paper`_. - drop_tks (bool, optional): Whether drops tokens in evaluation - use_residual (bool, optional): Makes this MoE layer a Residual MoE. - More information can be found in `Microsoft paper`_. - residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE - expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer - expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given - expert_args (optional): The args of expert when no instance is given - - .. _Switch Transformer paper: - https://arxiv.org/abs/2101.03961 - .. _ViT-MoE paper: - https://arxiv.org/abs/2106.05974 - .. _Microsoft paper: - https://arxiv.org/abs/2201.05596 - """ - - def __init__(self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args): - super().__init__() - - noisy_func = None - if noisy_policy is not None: - if noisy_policy == 'Jitter': - noisy_func = UniformNoiseGenerator() - elif noisy_policy == 'Gaussian': - noisy_func = NormalNoiseGenerator(num_experts) - else: - raise NotImplementedError("Unsupported input noisy policy") - - if top_k == 1: - moe_router_cls = Top1Router - elif top_k == 2: - moe_router_cls = Top2Router - else: - raise NotImplementedError("top_k > 2 is not supported yet") - - self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.use_residual = use_residual - if use_residual: - if residual_instance is not None: - self.residual_module = residual_instance - else: - assert expert_cls is not None, \ - "Expert class can't be None when residual instance is not given" - self.residual_module = expert_cls(**expert_args) - - with no_shard_zero_context(): - self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) - - if expert_instance is not None: - self.experts = expert_instance - else: - assert expert_cls is not None, \ - "Expert class can't be None when experts instance is not given" - self.experts = Experts(expert_cls, num_experts, **expert_args) - - self.moe_layer = MoeLayer(dim_model=dim_model, - num_experts=num_experts, - router=self.moe_router, - experts=self.experts) - - def forward(self, inputs: torch.Tensor): - moe_output, l_aux = self.moe_layer(inputs) - - if self.use_residual: - residual_output = self.residual_module(inputs) - combine_coef = self.residual_combine(inputs) - combine_coef = F.softmax(combine_coef, dim=-1) - output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] - else: - output = moe_output - - return output, l_aux +import math +from typing import Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import ( + COL_MOE_KERNEL_FLAG, + AllGather, + AllToAll, + MoeCombine, + MoeDispatch, + ReduceScatter, +) +from colossalai.nn.layer.moe.experts import Experts, MoeExperts +from 
colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router +from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator +from colossalai.utils import get_current_device +from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator + + +@no_shard_zero_decrator(is_replicated=True) +class MoeLayer(nn.Module): + """A MoE layer, that puts its input tensor to its gate and uses the output logits + to router all tokens, is mainly used to exchange all tokens for every expert across + the moe tensor group by all to all comunication. Then it will get the output of all + experts and exchange the output. At last returns the output of the moe system. + + Args: + dim_model (int): Dimension of model. + num_experts (int): The number of experts. + router (MoeRouter): Instance of router used in routing. + experts (MoeExperts): Instance of experts generated by Expert. + """ + + def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): + super().__init__() + self.d_model = dim_model + self.num_experts = num_experts + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) + self.router: MoeRouter = router + self.experts: MoeExperts = experts + self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False + self.ep_group = experts.dist_info.ep_group + self.ep_size = experts.dist_info.ep_size + self.num_local_experts = experts.num_local_experts + + nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) + + def a2a_process(self, dispatch_data: torch.Tensor): + expert_input = AllToAll.apply(dispatch_data, self.ep_group) + input_shape = expert_input.shape + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) + expert_output = self.experts(expert_input) + expert_output = expert_output.reshape(input_shape) + expert_output = AllToAll.apply(expert_output, self.ep_group) + return expert_output + + def tp_process(self, dispatch_data: torch.Tensor): + expert_in = AllGather.apply(dispatch_data, self.ep_group) + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group) + return expert_out + + def forward(self, inputs: torch.Tensor) -> Tuple: + # reshape the input tokens + tokens = inputs.reshape(-1, self.d_model) + + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + + if self.use_kernel: + dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) + else: + sec_mask_f = route_result_list[1].type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + + # dispatch_data [e, c, h] + if self.experts.comm_name == "all_to_all": + expert_output = self.a2a_process(dispatch_data) + elif self.experts.comm_name == "all_gather": + expert_output = self.tp_process(dispatch_data) + else: + raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " + "build function.") + # expert_output [e, c, h] + if self.use_kernel: + expert_output = expert_output.reshape(-1, self.d_model) + ans = MoeCombine.apply(expert_output, *route_result_list) + else: + combine_weights = 
route_result_list[0].type_as(inputs) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans = torch.matmul(combine_weights, expert_output) + + ans = ans.reshape(inputs.shape) + l_aux = self.router.pop_routing_loss() + return ans, l_aux + + +class MoeModule(nn.Module): + """A class for users to create MoE modules in their models. + + Args: + dim_model (int): Hidden dimension of training model + num_experts (int): The number experts + top_k (int, optional): The number of experts for dispatchment of each token + capacity_factor_train (float, optional): Capacity factor in routing during training + capacity_factor_eval (float, optional): Capacity factor in routing during evaluation + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. + 'Jitter' can be found in `Switch Transformer paper`_. + 'Gaussian' can be found in `ViT-MoE paper`_. + drop_tks (bool, optional): Whether drops tokens in evaluation + use_residual (bool, optional): Makes this MoE layer a Residual MoE. + More information can be found in `Microsoft paper`_. + residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE + expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer + expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given + expert_args (optional): The args of expert when no instance is given + + .. _Switch Transformer paper: + https://arxiv.org/abs/2101.03961 + .. _ViT-MoE paper: + https://arxiv.org/abs/2106.05974 + .. _Microsoft paper: + https://arxiv.org/abs/2201.05596 + """ + + def __init__(self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args): + super().__init__() + + noisy_func = None + if noisy_policy is not None: + if noisy_policy == 'Jitter': + noisy_func = UniformNoiseGenerator() + elif noisy_policy == 'Gaussian': + noisy_func = NormalNoiseGenerator(num_experts) + else: + raise NotImplementedError("Unsupported input noisy policy") + + if top_k == 1: + moe_router_cls = Top1Router + elif top_k == 2: + moe_router_cls = Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + + self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.use_residual = use_residual + if use_residual: + if residual_instance is not None: + self.residual_module = residual_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when residual instance is not given" + self.residual_module = expert_cls(**expert_args) + + with no_shard_zero_context(): + self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) + + if expert_instance is not None: + my_experts = expert_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when experts instance is not given" + my_experts = Experts(expert_cls, num_experts, **expert_args) + + self.moe_layer = 
MoeLayer(dim_model=dim_model, + num_experts=num_experts, + router=self.moe_router, + experts=my_experts) + + def forward(self, inputs: torch.Tensor): + moe_output, l_aux = self.moe_layer(inputs) + + if self.use_residual: + residual_output = self.residual_module(inputs) + combine_coef = self.residual_combine(inputs) + combine_coef = F.softmax(combine_coef, dim=-1) + output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] + else: + output = moe_output + + return output, l_aux diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py new file mode 100644 index 000000000..f99e74ea5 --- /dev/null +++ b/tests/test_moe/test_moe_checkpoint.py @@ -0,0 +1,54 @@ +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer.moe import load_moe_model, save_moe_model +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print +from tests.test_zero.common import CONFIG + + +def exam_moe_checkpoint(): + with ColoInitContext(device=get_current_device()): + model = MoeModel(checkpoint=True) + save_moe_model(model, 'temp_path.pth') + + with ColoInitContext(device=get_current_device()): + other_model = MoeModel(checkpoint=True) + load_moe_model(other_model, 'temp_path.pth') + + state_0 = model.state_dict() + state_1 = other_model.state_dict() + for k, v in state_0.items(): + u = state_1.get(k) + assert torch.equal(u.data, v.data) + + if dist.get_rank() == 0: + os.remove('temp_path.pth') + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_checkpoint() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_checkpoint(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_checkpoint(world_size=4)
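
Usage note: the new helpers are meant to be called symmetrically on every rank, exactly as the test above does. The following is a minimal sketch of that flow; it assumes the distributed environment has already been launched (for example with `colossalai.launch`, as in `_run_dist`) and reuses the `MoeModel` helper from the tests, so the model construction and the checkpoint path are illustrative rather than prescriptive.

```python
# Minimal sketch of the intended save/load flow, mirroring the test above.
# Assumes the process group is already launched and MOE_CONTEXT has not been
# set up yet; 'moe_ckpt.pth' is an arbitrary illustrative path.
import os

import torch.distributed as dist

from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel

MOE_CONTEXT.setup(seed=42)    # build the expert-parallel groups before creating experts

with ColoInitContext(device=get_current_device()):
    model = MoeModel(checkpoint=True)

# Every rank computes the gathered state dict; only rank 0 writes the file,
# and the barrier inside save_moe_model keeps the other ranks in step.
save_moe_model(model, 'moe_ckpt.pth')

with ColoInitContext(device=get_current_device()):
    restored = MoeModel(checkpoint=True)

# Each rank reads the full checkpoint and remaps the globally numbered expert
# parameters onto its own local experts before calling load_state_dict.
load_moe_model(restored, 'moe_ckpt.pth')

if dist.get_rank() == 0:
    os.remove('moe_ckpt.pth')
```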
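
For reference, the key manipulation inside `load_moe_model` can be reproduced on plain dictionaries. Each rank keeps `num_local_experts` experts, and its local expert `i` holds the parameters of global expert `ep_rank * num_local_experts + i`; the loader copies those global entries onto the local keys and drops the keys the rank does not own. The helper and the toy numbers below are hypothetical and only illustrate that arithmetic.

```python
# Hypothetical, self-contained illustration of the key remapping performed by
# load_moe_model; plain floats stand in for parameter tensors.

def remap_expert_keys(state_dict, prefix, ep_rank, num_local, num_total, param_names):
    # Local expert i on this rank corresponds to global expert
    # ep_rank * num_local + i, so copy that global entry onto the local key.
    for i in range(num_local):
        expert_id = ep_rank * num_local + i
        for name in param_names:
            state_dict[f'{prefix}.experts.{i}.{name}'] = \
                state_dict[f'{prefix}.experts.{expert_id}.{name}']
    # Remove the keys of experts this rank does not own, so that
    # load_state_dict sees exactly the locally expected entries.
    for name in param_names:
        for i in range(num_local, num_total):
            state_dict.pop(f'{prefix}.experts.{i}.{name}')
    return state_dict


# Toy case: 4 experts saved with global indices, expert-parallel size 2,
# hence 2 local experts per rank; shown here for the rank with ep_rank=1.
ckpt = {f'moe.moe_layer.experts.experts.{i}.weight': float(i) for i in range(4)}
local = remap_expert_keys(dict(ckpt), 'moe.moe_layer.experts',
                          ep_rank=1, num_local=2, num_total=4,
                          param_names=['weight'])
# Rank 1 ends up with global experts 2 and 3 mapped to its local slots 0 and 1:
# {'moe.moe_layer.experts.experts.0.weight': 2.0,
#  'moe.moe_layer.experts.experts.1.weight': 3.0}
print(local)
```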