Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-13 21:55:46 +00:00)

[MOE] polish moe_env (#467)

Commit: aff9d354f7
Parent: bccbc15861
@@ -4,4 +4,4 @@
 from colossalai.context import ParallelContext, MoeContext

 global_context = ParallelContext.get_instance()
-moe_context = MoeContext.get_instance()
+MOE_CONTEXT = MoeContext.get_instance()

@@ -1,4 +1,4 @@
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler

@@ -30,5 +30,5 @@ class MoeGradientHandler(BaseGradientHandler):
             bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))

         for ep_size in param_dict:
-            if ep_size != 1 and ep_size != moe_env.world_size:
-                bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group)
+            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)
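The singleton that used to be imported as moe_context / moe_env is now the module-level MOE_CONTEXT, and call sites read its world_size and per-ep-size process groups directly. A hedged sketch, not part of the commit, of how downstream code consumes it (attribute names taken from the hunks above; the distributed setup itself is assumed to have been done by colossalai elsewhere):

import torch.distributed as dist
from colossalai.core import MOE_CONTEXT

def allreduce_moe_grads(epsize_param_dict):
    # epsize_param_dict maps expert parallel size -> list of parameters, as produced
    # by get_moe_epsize_param_dict in the gradient handler above.
    for ep_size, params in epsize_param_dict.items():
        # ep_size == 1 is handled by the plain data-parallel group; ep_size equal to
        # the world size has no extra data-parallel replicas to reduce across.
        if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
            group = MOE_CONTEXT.information[ep_size].dp_group
            for p in params:
                if p.grad is not None:
                    dist.all_reduce(p.grad, group=group)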
@@ -4,11 +4,11 @@ from torch import Tensor
 from typing import Any, Tuple, Optional
 from torch.distributed import ProcessGroup

-U_CUDA_MODE = False
+COL_MOE_KERNEL_FLAG = False
 try:
     import colossal_moe_cuda

-    U_CUDA_MODE = True
+    COL_MOE_KERNEL_FLAG = True
 except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")

@@ -17,7 +17,6 @@ class AllGather(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group

@@ -40,7 +39,6 @@ class ReduceScatter(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group

@@ -149,7 +147,7 @@ class MoeCombine(torch.autograd.Function):
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and U_CUDA_MODE:
+    if flag and COL_MOE_KERNEL_FLAG:
         return colossal_moe_cuda.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
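The CUDA extension gate is renamed from U_CUDA_MODE to COL_MOE_KERNEL_FLAG, and moe_cumsum falls back to plain PyTorch when the extension is missing or the size heuristic fails. A small self-contained illustration of what the fallback branch computes (example shapes only):

import torch

def moe_cumsum_reference(mask: torch.Tensor) -> torch.Tensor:
    # mask: [s, e] one-hot expert assignment; result[i, e] is the number of tokens
    # assigned to expert e among tokens 0..i, minus one, so the first token per
    # expert gets rank 0.
    return torch.cumsum(mask, dim=0) - 1

mask = torch.tensor([[1, 0], [0, 1], [1, 0]])
print(moe_cumsum_reference(mask))
# tensor([[ 0, -1],
#         [ 0,  0],
#         [ 1,  0]])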
@@ -2,18 +2,24 @@ import math

 import torch
 import torch.nn as nn
-from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
+from colossalai.core import MOE_CONTEXT


 class MoeExperts(nn.Module):
+    """Basic class for experts in MoE. It stores what kind of communication experts use
+    to exchange tokens, how many experts there are on a single GPU, and parallel information such as
+    expert parallel size, data parallel size and their distributed communication groups.
+    """

-    def __init__(self, comm: str):
+    def __init__(self, comm_name: str, num_experts: int):
         super().__init__()
-        assert comm in {"all_to_all", "all_gather"}, \
+        assert comm_name in {"all_to_all", "all_gather"}, \
             "This kind of communication has not been implemented yet.\n Please use Experts build function."
-        self.comm = comm
+        self.comm_name = comm_name
+        # Get the configuration of experts' deployment and parallel information from moe context
+        self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


 class Experts(MoeExperts):
@@ -29,53 +35,48 @@ class Experts(MoeExperts):
     """

     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__("all_to_all")
+        super().__init__("all_to_all", num_experts)

-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divided by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-
-        with seed(ParallelMode.MOE_MODEL):
-            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
+        # Use seed to make every expert different from others
+        with seed(ParallelMode.TENSOR):
+            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)])

+        # Attach parallel information for all parameters in Experts
         for exp in self.experts:
             for param in exp.parameters():
-                param.__setattr__('moe_param', True)
+                param.__setattr__('moe_info', self.dist_info)

-        self.num_local_experts = num_local_experts
-
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
+        # Split inputs for each expert
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []

+        # Get outputs from each expert
         for i in range(self.num_local_experts):
             expert_output.append(self.experts[i](expert_input[i]))

+        # Concatenate all outputs together
         output = torch.cat(expert_output, dim=1).contiguous()
         return output


 class FFNExperts(MoeExperts):
+    """Use torch.bmm to speed up for multiple experts.
+    """

     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_to_all")
+        super().__init__("all_to_all", num_experts)

-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divided by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-
-        self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
-        self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
-
-        self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
-        self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
+        self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
+
+        self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))

         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)

-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)
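Experts now derives its local expert count and parallel info from MOE_CONTEXT.get_info instead of moe_env.model_parallel_size. Its forward pass keeps the same pattern: chunk the dispatched tensor along the expert dimension, run each local expert, and concatenate. A standalone sketch of that pattern with made-up sizes:

import torch
import torch.nn as nn

num_local_experts, capacity, d_model = 2, 4, 8
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_local_experts)])

# dispatch_data laid out as [ep_size, num_local_experts, capacity, d_model];
# ep_size is 1 here because this sketch runs on a single process.
dispatch_data = torch.randn(1, num_local_experts, capacity, d_model)

chunks = torch.chunk(dispatch_data, num_local_experts, dim=1)
outputs = [experts[i](chunks[i]) for i in range(num_local_experts)]
combined = torch.cat(outputs, dim=1).contiguous()
print(combined.shape)  # torch.Size([1, 2, 4, 8])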
@@ -85,7 +86,7 @@ class FFNExperts(MoeExperts):
         self.drop = nn.Dropout(p=drop_rate)

         for param in self.parameters():
-            param.__setattr__('moe_param', True)
+            param.__setattr__('moe_info', self.dist_info)

     def forward(self, inputs):    # inputs [g, el, c, h]

@@ -99,9 +100,9 @@ class FFNExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)

-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         with seed(ParallelMode.TENSOR):
             outputs = self.drop(out_model)    # outputs [el, gc, h]

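FFNExperts keeps one weight slab per local expert and evaluates all of them with a single batched torch.baddbmm, which is what the [e, gc, h] comments above refer to. A minimal shape sketch (illustrative sizes, GELU substituted for the configurable activation):

import torch

e, tokens, d_model, d_ff = 2, 6, 8, 16
x  = torch.randn(e, tokens, d_model)            # [e, gc, h]
w1 = torch.randn(e, d_model, d_ff)
b1 = torch.randn(e, 1, d_ff)
w2 = torch.randn(e, d_ff, d_model)
b2 = torch.randn(e, 1, d_model)

hidden = torch.baddbmm(b1, x, w1)               # [e, gc, d_ff]; bias broadcast over tokens
out = torch.baddbmm(b2, torch.nn.functional.gelu(hidden), w2)
print(out.shape)                                # torch.Size([2, 6, 8])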
@@ -111,14 +112,18 @@ class FFNExperts(MoeExperts):


 class TPExperts(MoeExperts):
+    """Use tensor parallelism to split each expert evenly, which can deploy experts in
+    the case that the number of experts can't be divided by the maximum expert parallel size or
+    the maximum expert parallel size can't be divided by the number of experts.
+    """

     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__("all_gather")
+        super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

-        assert d_ff % moe_env.model_parallel_size == 0, \
-            "d_ff should be divided by moe model size"
+        assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
+            "d_ff should be divided by the maximum expert parallel size"

-        p_ff = d_ff // moe_env.model_parallel_size
+        p_ff = d_ff // MOE_CONTEXT.max_ep_size

         self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
         self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))

@@ -129,7 +134,7 @@ class TPExperts(MoeExperts):
         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)

-        with seed(ParallelMode.MOE_MODEL):
+        with seed(ParallelMode.TENSOR):
             nn.init.trunc_normal_(self.w1, std=s1)
             nn.init.trunc_normal_(self.b1, std=s1)
             nn.init.trunc_normal_(self.w2, std=s2)

@@ -139,9 +144,9 @@ class TPExperts(MoeExperts):
         self.act = nn.GELU() if activation is None else activation
         self.drop = nn.Dropout(p=drop_rate)

-        self.w1.__setattr__('moe_param', True)
-        self.w2.__setattr__('moe_param', True)
-        self.b1.__setattr__('moe_param', True)
+        self.w1.__setattr__('moe_info', self.dist_info)
+        self.w2.__setattr__('moe_info', self.dist_info)
+        self.b1.__setattr__('moe_info', self.dist_info)

     def forward(self, inputs):    # inputs [g, e, c, h]

@@ -155,9 +160,9 @@ class TPExperts(MoeExperts):
         out_ff = torch.baddbmm(self.b1, inputs, self.w1)
         out_act = self.act(out_ff)
         with seed(ParallelMode.TENSOR):
-            inter = self.drop(out_act)
+            out_inter = self.drop(out_act)

-        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        out_model = torch.baddbmm(self.b2, out_inter, self.w2)
         outputs = self.drop(out_model)    # outputs [e, gc, h]

         outputs = outputs.reshape(inshape)
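TPExperts takes the orthogonal route: every rank stores all experts but only a 1/max_ep_size slice of each FFN's hidden dimension, which is why only d_ff has to divide evenly. A small sketch of the resulting parameter shapes (illustrative numbers, no distributed state):

d_ff, max_ep_size, num_experts, d_model = 64, 4, 3, 16
assert d_ff % max_ep_size == 0
p_ff = d_ff // max_ep_size   # hidden slice held by one rank: 16

# per-rank parameter shapes, matching the w1/b1/w2 definitions above
w1_shape = (num_experts, d_model, p_ff)
w2_shape = (num_experts, p_ff, d_model)
print(p_ff, w1_shape, w2_shape)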
@@ -4,14 +4,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.global_variables import moe_env
-from colossalai.context import ParallelMode
+from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
-from typing import Callable
+from typing import Callable, Optional
+from torch.distributed import ProcessGroup


 class Top1Router(nn.Module):

@@ -19,8 +18,8 @@ class Top1Router(nn.Module):
     for routing usage. More detailed function can be found in the paper about Switch Transformer
     of Google.

-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param select_policy: The policy about tokens selection
     :param noisy_func: Noisy function used in logits

@@ -66,7 +65,7 @@ class Top1Router(nn.Module):
         assert capacity > 0
         return capacity

-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):

         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

@@ -82,10 +81,10 @@ class Top1Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass

@@ -103,7 +102,7 @@ class Top1Router(nn.Module):

         ranks = torch.sum(mask * ranks, dim=-1)

-        if cuda_mode:
+        if use_kernel:
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
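Both routers now receive the expert-parallel process group explicitly and push their auxiliary balancing loss into MOE_CONTEXT. A tiny worked example of that loss on three tokens and two experts (no distributed setup, values chosen for illustration):

import torch
import torch.nn.functional as F

logits = F.softmax(torch.tensor([[2.0, 0.1], [1.5, 0.2], [0.1, 1.9]]), dim=-1)  # [s=3, e=2]
top1_idx = torch.argmax(logits, dim=-1)
mask = F.one_hot(top1_idx, num_classes=2)

num_experts = 2
me = torch.mean(logits, dim=0)        # mean routing probability per expert
ce = torch.mean(mask.float(), dim=0)  # fraction of tokens assigned to each expert
l_aux = num_experts * torch.sum(me * ce)
print(l_aux)  # scalar, minimized when probabilities and assignments are balanced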
@@ -120,8 +119,8 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More detailed function can be found in the paper about ViT-MoE.

-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param noisy_func: Noisy function used in logits
     :param drop_tks: Whether to drop tokens in evaluation

@@ -157,7 +156,7 @@ class Top2Router(nn.Module):
         assert capacity > 0
         return capacity

-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
         # inputs: [s, h]
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)

@@ -177,10 +176,10 @@ class Top2Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass

@@ -195,7 +194,7 @@ class Top2Router(nn.Module):
         rank1 = torch.sum(mask1 * rank1, dim=-1)
         rank2 = torch.sum(mask2 * rank2, dim=-1)

-        if cuda_mode:
+        if use_kernel:
             mask1 = torch.sum(mask1, dim=-1)
             mask2 = torch.sum(mask2, dim=-1)

@@ -241,34 +240,36 @@ class MoeLayer(nn.Module):
         self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
         self.router = router
         self.experts = experts
-        self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.ep_group = experts.dist_info.ep_group
+        self.ep_size = experts.dist_info.ep_size
+        self.num_local_experts = experts.num_local_experts

     def a2a_process(self, dispatch_data: torch.Tensor):
-        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_input = AllToAll.apply(dispatch_data, self.ep_group)

         input_shape = expert_input.shape

-        expert_input = expert_input.reshape(moe_env.model_parallel_size,
-                                            self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
+        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)

         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)

-        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
+        expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output

     def tp_process(self, dispatch_data: torch.Tensor):
-        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_in = AllGather.apply(dispatch_data, self.ep_group)
         expert_out = self.experts(expert_in)
-        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
-        router_res = self.router(gate_output, self.cuda_mode)
+        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

-        if self.cuda_mode:
+        if self.use_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:

@@ -276,16 +277,16 @@ class MoeLayer(nn.Module):
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

         # dispatch_data [e, c, h]
-        if self.experts.comm == "all_to_all":
+        if self.experts.comm_name == "all_to_all":
             expert_output = self.a2a_process(dispatch_data)
-        elif self.experts.comm == "all_gather":
+        elif self.experts.comm_name == "all_gather":
             expert_output = self.tp_process(dispatch_data)
         else:
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]

-        if self.cuda_mode:
+        if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:
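With ep_size and num_local_experts cached from experts.dist_info, a2a_process reshapes the dispatched [e, c, h] tensor so each rank's experts see only their own slice between the two all-to-all calls. A single-process shape walk-through (no communication, sizes made up):

import torch

ep_size, num_local_experts, capacity, d_model = 2, 3, 4, 8
num_experts = ep_size * num_local_experts

dispatch_data = torch.randn(num_experts, capacity, d_model)          # [e, c, h]
expert_input = dispatch_data.reshape(ep_size, num_local_experts, -1, d_model)
print(expert_input.shape)   # torch.Size([2, 3, 4, 8])
# after the experts run, the output is reshaped back to the original [e, c, h]
expert_output = expert_input.reshape(dispatch_data.shape)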
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts


@@ -36,7 +36,7 @@ class UniformNoiseGenerator:
     :type eps: float
     """

-    def __init__(self, eps: float):
+    def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
                                                            high=torch.tensor(1.0 + eps,
                                                                              device=get_current_device())).rsample

@@ -55,10 +55,10 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):


 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    moe_mp_size = moe_env.model_parallel_size
-    if num_experts % moe_mp_size == 0:
+    mep_size = MOE_CONTEXT.max_ep_size
+    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % moe_mp_size == 0:
+    elif d_ff % mep_size == 0:
         return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     else:
-        raise NotImplementedError(f"Cannot build {num_experts} experts in {moe_mp_size} GPUs.")
+        raise NotImplementedError(f"Cannot build {num_experts} experts in {mep_size} GPUs.")
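build_ffn_experts now also accepts the case where the maximum expert parallel size is a multiple of num_experts, and falls back to the tensor-parallel variant only when d_ff splits evenly. The selection rule in isolation (a hedged reimplementation of just the branching, with no colossalai state):

def pick_expert_impl(num_experts: int, d_ff: int, max_ep_size: int) -> str:
    # FFNExperts whenever the expert count and max_ep_size divide one another,
    # otherwise TPExperts if the hidden dimension can be sliced across ranks.
    if num_experts % max_ep_size == 0 or max_ep_size % num_experts == 0:
        return "FFNExperts"
    elif d_ff % max_ep_size == 0:
        return "TPExperts"
    raise NotImplementedError(f"Cannot build {num_experts} experts in {max_ep_size} GPUs.")

print(pick_expert_impl(num_experts=8, d_ff=3072, max_ep_size=4))   # FFNExperts
print(pick_expert_impl(num_experts=3, d_ff=3072, max_ep_size=4))   # TPExperts (3072 % 4 == 0)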
@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT


 @LOSSES.register_module

@@ -14,6 +14,7 @@ class MoeCrossEntropyLoss(_Loss):

     :type aux_weight: float, optional
     """
+
     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)

@@ -21,7 +22,7 @@ class MoeCrossEntropyLoss(_Loss):

     def forward(self, *args):
         main_loss = self.loss(*args)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss


@@ -37,6 +38,7 @@ class MoeLoss(_Loss):
     :type aux_weight: float
     :type loss_fn: Callable
     """
+
     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)

@@ -44,5 +46,5 @@ class MoeLoss(_Loss):

     def forward(self, *args, **kwargs):
         main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
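Both loss wrappers simply add the auxiliary routing loss accumulated in MOE_CONTEXT to the main criterion. A hedged usage sketch (the import path and the surrounding MoE model setup are assumptions, not shown in this diff):

import torch.nn as nn
from colossalai.nn.loss import MoeLoss   # import path assumed

# wraps any criterion; the routers fill MOE_CONTEXT during the model's forward pass
criterion = MoeLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss)
# loss = criterion(logits, targets)     # main CE loss + 0.01 * MOE_CONTEXT.get_loss()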
@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List

@@ -45,7 +45,7 @@ def sync_moe_model_param(model: nn.Module):

         for ep_size in param_dict:
             # When ep_size = world_size, communication is not needed
-            if ep_size != 1 and ep_size != moe_env.world_size:
-                src_rank = dist.get_rank(moe_env.information[ep_size].ep_group)
+            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
                 for param in param_dict[ep_size]:
                     dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
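sync_moe_model_param relies on each MoE parameter carrying the moe_info attribute attached in the expert classes above, so parameters can be bucketed by expert parallel size before broadcasting. A simplified, assumed reimplementation of that grouping step (the real helper is get_moe_epsize_param_dict; this sketch only illustrates the idea):

from collections import defaultdict
from typing import Dict, List

import torch.nn as nn


def group_params_by_ep_size(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
    epsize_param_dict = defaultdict(list)
    for param in model.parameters():
        # Parameters without moe_info are plain data-parallel parameters (ep_size 1)
        if not hasattr(param, 'moe_info'):
            epsize_param_dict[1].append(param)
        else:
            epsize_param_dict[param.moe_info.ep_size].append(param)
    return dict(epsize_param_dict)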