Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -5,6 +5,17 @@ from .routers import MoeRouter, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
"Experts",
"FFNExperts",
"TPExperts",
"Top1Router",
"Top2Router",
"MoeLayer",
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"build_ffn_experts",
"MoeModule",
"MoeRouter",
"save_moe_model",
"load_moe_model",
]
@@ -18,18 +18,18 @@ def build_moe_if_not_prebuilt():
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder

moe = MOEBuilder().load()


class AllGather(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

global moe

if moe is None:
from colossalai.kernel.op_builder import MOEBuilder

moe = MOEBuilder().load()

if ctx is not None:
@@ -51,7 +51,6 @@ class AllGather(torch.autograd.Function):


class ReduceScatter(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
@@ -98,7 +97,6 @@ class AllToAll(torch.autograd.Function):


class MoeDispatch(torch.autograd.Function):

@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
@@ -124,7 +122,6 @@ class MoeDispatch(torch.autograd.Function):


class MoeCombine(torch.autograd.Function):

@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
@@ -137,7 +134,7 @@ class MoeCombine(torch.autograd.Function):
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()

fp16_flag = (expert_tokens.dtype == torch.float16)
fp16_flag = expert_tokens.dtype == torch.float16
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
@@ -155,8 +152,7 @@ class MoeCombine(torch.autograd.Function):
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
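For orientation, a minimal sketch of how these autograd functions are invoked (not part of the diff); the import path, shapes, and process group are illustrative assumptions:

    import torch
    import torch.distributed as dist
    from colossalai.nn.layer.moe._operation import AllGather, ReduceScatter  # assumed module path

    # Subclasses of torch.autograd.Function are called through .apply(); PyTorch
    # supplies ctx, so only the remaining forward() arguments are passed.
    # Assumes torch.distributed has already been initialized.
    x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
    gathered = AllGather.apply(x, dist.group.WORLD)      # forward(ctx, inputs, group)
    scattered = ReduceScatter.apply(gathered, dist.group.WORLD)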
@@ -16,7 +16,7 @@ def load_moe_model(model: nn.Module, load_path: str):
state_dict = torch.load(load_path)

for prefix, module in model.named_modules():
if prefix.endswith('.moe_layer.experts'):
if prefix.endswith(".moe_layer.experts"):
# this module should be an Experts instance
assert isinstance(module, MoeExperts)
@@ -25,16 +25,16 @@ def load_moe_model(model: nn.Module, load_path: str):
for i in range(num_local):
expert_id = ep_rank * num_local + i
for name, _ in module.experts[i].named_parameters():
cur_key = f'{prefix}.experts.{i}.{name}'
param_key = f'{prefix}.experts.{expert_id}.{name}'
cur_key = f"{prefix}.experts.{i}.{name}"
param_key = f"{prefix}.experts.{expert_id}.{name}"
load_param = state_dict[param_key]
state_dict[cur_key] = load_param

for name, _ in module.experts[0].named_parameters():
pop_pre = f'{prefix}.experts.'
pop_suf = f'.{name}'
pop_pre = f"{prefix}.experts."
pop_suf = f".{name}"
for i in range(num_local, module.num_total_experts):
pop_key = f'{pop_pre}{i}{pop_suf}'
pop_key = f"{pop_pre}{i}{pop_suf}"
state_dict.pop(pop_key)

model.load_state_dict(state_dict)
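A hypothetical round trip with the two checkpoint helpers; save_moe_model's exact signature is not shown in this diff, so the call below assumes it mirrors load_moe_model(model, load_path):

    from colossalai.nn.layer.moe import save_moe_model, load_moe_model  # both exported in __all__ above

    save_moe_model(model, "moe_ckpt.pth")  # assumed signature: (model, save_path)
    load_moe_model(model, "moe_ckpt.pth")  # remaps global expert ids onto the local experts, then load_state_dict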
@@ -20,8 +20,10 @@ class MoeExperts(nn.Module):

def __init__(self, comm_name: str, num_experts: int):
super().__init__()
assert comm_name in {"all_to_all", "all_gather"}, \
"This kind of communication has not been implemented yet.\n Please use Experts build function."
assert comm_name in {
"all_to_all",
"all_gather",
}, "This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm_name = comm_name
self.num_total_experts = num_experts
# Get the configuration of experts' deployment and parallel information from moe context
@@ -50,7 +52,7 @@ class Experts(MoeExperts):
# Attach parallel information for all parameters in Experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_info', self.dist_info)
param.__setattr__("moe_info", self.dist_info)

def forward(self, inputs: torch.Tensor):
# Split inputs for each expert
@@ -65,7 +67,7 @@ class Experts(MoeExperts):
output = torch.cat(expert_output, dim=1).contiguous()
return output

def state_dict(self, destination=None, prefix='', keep_vars=False):
def state_dict(self, destination=None, prefix="", keep_vars=False):
assert keep_vars == False, "Only support keep_vars=False now"
dp_rank = dist.get_rank(self.dist_info.dp_group)
ep_rank = dist.get_rank(self.dist_info.ep_group)
@@ -79,11 +81,11 @@ class Experts(MoeExperts):
example_submodule = subm

if dp_rank == 0:
local_prefix = prefix + 'experts.'
local_prefix = prefix + "experts."
buffer_module = deepcopy(example_submodule)
for i in range(self.num_total_experts):
source_rank = i // self.num_local_experts
current_prefix = local_prefix + str(i) + '.'
current_prefix = local_prefix + str(i) + "."
comm_module = submodule_dict.get(i, buffer_module)
for name, param in comm_module.named_parameters():
dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
@@ -94,8 +96,7 @@ class Experts(MoeExperts):


class FFNExperts(MoeExperts):
"""Use torch.bmm to speed up for multiple experts.
"""
"""Use torch.bmm to speed up for multiple experts."""

def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_to_all", num_experts)
@@ -119,10 +120,9 @@ class FFNExperts(MoeExperts):
self.drop = nn.Dropout(p=drop_rate)

for param in self.parameters():
param.__setattr__('moe_info', self.dist_info)

def forward(self, inputs): # inputs [g, el, c, h]
param.__setattr__("moe_info", self.dist_info)

def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
@@ -137,7 +137,7 @@ class FFNExperts(MoeExperts):

out_model = torch.baddbmm(self.b2, out_inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = self.drop(out_model) # outputs [el, gc, h]

outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
@@ -153,8 +153,7 @@ class TPExperts(MoeExperts):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather", MOE_CONTEXT.max_ep_size)

assert d_ff % MOE_CONTEXT.max_ep_size == 0, \
"d_ff should be divide by maximum expert parallel size"
assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size"

p_ff = d_ff // MOE_CONTEXT.max_ep_size
@@ -177,12 +176,11 @@ class TPExperts(MoeExperts):
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)

self.w1.__setattr__('moe_info', self.dist_info)
self.w2.__setattr__('moe_info', self.dist_info)
self.b1.__setattr__('moe_info', self.dist_info)

def forward(self, inputs): # inputs [g, e, c, h]
self.w1.__setattr__("moe_info", self.dist_info)
self.w2.__setattr__("moe_info", self.dist_info)
self.b1.__setattr__("moe_info", self.dist_info)

def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
@@ -196,8 +194,8 @@ class TPExperts(MoeExperts):
out_inter = self.drop(out_act)

out_model = torch.baddbmm(self.b2, out_inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = self.drop(out_model) # outputs [e, gc, h]

outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]
return outputs # outputs [g, e, c, h]
@@ -89,8 +89,9 @@ class MoeLayer(nn.Module):
elif self.experts.comm_name == "all_gather":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n Please use Experts " "build function."
)
# expert_output [e, c, h]
if self.use_kernel:
expert_output = expert_output.reshape(-1, self.d_model)
@@ -135,27 +136,29 @@ class MoeModule(nn.Module):
https://arxiv.org/abs/2201.05596
"""

def __init__(self,
dim_model: int,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
use_residual: bool = False,
residual_instance: Optional[nn.Module] = None,
expert_instance: Optional[MoeExperts] = None,
expert_cls: Optional[Type[nn.Module]] = None,
**expert_args):
def __init__(
self,
dim_model: int,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
use_residual: bool = False,
residual_instance: Optional[nn.Module] = None,
expert_instance: Optional[MoeExperts] = None,
expert_cls: Optional[Type[nn.Module]] = None,
**expert_args,
):
super().__init__()

noisy_func = None
if noisy_policy is not None:
if noisy_policy == 'Jitter':
if noisy_policy == "Jitter":
noisy_func = UniformNoiseGenerator()
elif noisy_policy == 'Gaussian':
elif noisy_policy == "Gaussian":
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError("Unsupported input noisy policy")
@@ -167,18 +170,19 @@ class MoeModule(nn.Module):
else:
raise NotImplementedError("top_k > 2 is not supported yet")

self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.moe_router = moe_router_cls(
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.use_residual = use_residual
if use_residual:
if residual_instance is not None:
self.residual_module = residual_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when residual instance is not given"
assert expert_cls is not None, "Expert class can't be None when residual instance is not given"
self.residual_module = expert_cls(**expert_args)

with no_shard_zero_context():
@@ -187,14 +191,12 @@ class MoeModule(nn.Module):
if expert_instance is not None:
my_experts = expert_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when experts instance is not given"
assert expert_cls is not None, "Expert class can't be None when experts instance is not given"
my_experts = Experts(expert_cls, num_experts, **expert_args)

self.moe_layer = MoeLayer(dim_model=dim_model,
num_experts=num_experts,
router=self.moe_router,
experts=my_experts)
self.moe_layer = MoeLayer(
dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts
)

def forward(self, inputs: torch.Tensor):
moe_output, l_aux = self.moe_layer(inputs)
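A minimal construction sketch based on the MoeModule signature above; the expert class and sizes below are illustrative stand-ins rather than values taken from the commit:

    moe = MoeModule(
        dim_model=768,
        num_experts=8,
        top_k=2,                   # top_k == 2 selects Top2Router
        noisy_policy="Gaussian",   # installs NormalNoiseGenerator(num_experts)
        expert_cls=MyExpertFFN,    # hypothetical nn.Module subclass
        d_model=768,               # extra kwargs are forwarded to expert_cls via **expert_args
        d_ff=3072,
    )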
@@ -1,226 +1,235 @@
|
||||
import math
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.context import MOE_CONTEXT
|
||||
from colossalai.nn.layer.moe._operation import moe_cumsum
|
||||
from typing import Callable, Optional
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
"""Base class for all MoE routers.
|
||||
Args:
|
||||
k_value (int): The value of top_k.
|
||||
capacity_factor_train (float): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float): Capacity factor in routing of evaluation.
|
||||
min_capacity (int): The minimum number of the capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
self.capacity_factor_eval = capacity_factor_eval
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.drop_tks = drop_tks
|
||||
self._routing_loss = None
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return capacity
|
||||
|
||||
def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
|
||||
assert self._routing_loss is None
|
||||
self._routing_loss = aux_loss
|
||||
|
||||
def pop_routing_loss(self) -> torch.Tensor:
|
||||
assert self._routing_loss is not None
|
||||
reservation = self._routing_loss
|
||||
self._routing_loss = None
|
||||
return reservation
|
||||
|
||||
|
||||
class Top1Router(MoeRouter):
|
||||
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More detailed function can be found in the paper about Switch Transformer
|
||||
of Google.
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert.
|
||||
select_policy (str, optional): The policy about tokens selection.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
|
||||
high=torch.tensor(1.0,
|
||||
device=get_current_device())).rsample
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
|
||||
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
logits = F.softmax(inputs, dim=-1)
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
# caculate the auxiliary loss
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(mask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce)
|
||||
self.set_routing_loss(l_aux)
|
||||
|
||||
if not self.training and not self.drop_tks:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
if self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * logits.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More detailed function can be found in the paper about ViT-MoE.
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
|
||||
# inputs: [s, h]
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
top1_idx = torch.argmax(logits, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = (mask1 + mask2) # loss: [s, e]
|
||||
|
||||
# caculate the auxiliary loss
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(cmask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
|
||||
self.set_routing_loss(l_aux)
|
||||
|
||||
if not self.training and not self.drop_tks:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
rank1 = moe_cumsum(mask1) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
weight1 = mask1 * logits.type_as(inputs)
|
||||
weight2 = mask2 * logits.type_as(inputs)
|
||||
rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
cb_weight = cb_weight1 + cb_weight2
|
||||
sec_mask = cb_weight.bool()
|
||||
|
||||
return cb_weight, sec_mask
|
||||
import math
|
||||
from abc import ABC
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.nn.layer.moe._operation import moe_cumsum
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
"""Base class for all MoE routers.
|
||||
Args:
|
||||
k_value (int): The value of top_k.
|
||||
capacity_factor_train (float): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float): Capacity factor in routing of evaluation.
|
||||
min_capacity (int): The minimum number of the capacity of each expert.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
self.capacity_factor_eval = capacity_factor_eval
|
||||
self.min_capacity = min_capacity
|
||||
self.noisy_func = noisy_func
|
||||
self.drop_tks = drop_tks
|
||||
self._routing_loss = None
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
||||
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return capacity
|
||||
|
||||
def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
|
||||
assert self._routing_loss is None
|
||||
self._routing_loss = aux_loss
|
||||
|
||||
def pop_routing_loss(self) -> torch.Tensor:
|
||||
assert self._routing_loss is not None
|
||||
reservation = self._routing_loss
|
||||
self._routing_loss = None
|
||||
return reservation
|
||||
|
||||
|
||||
class Top1Router(MoeRouter):
|
||||
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More detailed function can be found in the paper about Switch Transformer
|
||||
of Google.
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert.
|
||||
select_policy (str, optional): The policy about tokens selection.
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
|
||||
).rsample
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
logits = F.softmax(inputs, dim=-1)
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
top1_idx = torch.argmax(inputs, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
# caculate the auxiliary loss
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(mask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce)
|
||||
self.set_routing_loss(l_aux)
|
||||
|
||||
if not self.training and not self.drop_tks:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
if self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * logits.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask
|
||||
|
||||
|
||||
class Top2Router(MoeRouter):
|
||||
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
|
||||
for routing usage. More detailed function can be found in the paper about ViT-MoE.
|
||||
Args:
|
||||
capacity_factor_train (float, optional): Capacity factor in routing of training.
|
||||
capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
|
||||
min_capacity (int, optional): The minimum number of the capacity of each expert
|
||||
noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
|
||||
drop_tks (bool, optional): Whether drops tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Callable = None,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
|
||||
# inputs: [s, h]
|
||||
if self.noisy_func is not None and self.training:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
assert inputs.dtype == torch.float
|
||||
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
top1_idx = torch.argmax(logits, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = mask1 + mask2 # loss: [s, e]
|
||||
|
||||
# caculate the auxiliary loss
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(cmask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
|
||||
self.set_routing_loss(l_aux)
|
||||
|
||||
if not self.training and not self.drop_tks:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
|
||||
capacity = max_num.item()
|
||||
|
||||
rank1 = moe_cumsum(mask1) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
if use_kernel:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
weight1 = mask1 * logits.type_as(inputs)
|
||||
weight2 = mask2 * logits.type_as(inputs)
|
||||
rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
cb_weight = cb_weight1 + cb_weight2
|
||||
sec_mask = cb_weight.bool()
|
||||
|
||||
return cb_weight, sec_mask
|
||||
|
@@ -1,68 +1,71 @@
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
from colossalai.context.moe_context import MOE_CONTEXT
from .experts import FFNExperts, TPExperts


class ForceFP32Parameter(torch.nn.Parameter):

def half(self, memory_format=None):
return self.data.clone()


class NormalNoiseGenerator:
"""Generates a random noisy mask for logits tensor.

All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
`E = the number of experts`.

Args:
num_experts (int): The number of experts.
"""

def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2,
device=get_current_device())).rsample

def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy


class UniformNoiseGenerator:
"""Generates a random noisy mask for logits tensor.
copied from mesh tensorflow:
Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.

Args:
eps (float, optional): Epsilon in generator, defaults 1e-2.
"""

def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps,
device=get_current_device())).rsample

def __call__(self, inputs: torch.Tensor):
noisy = self.uniform(inputs.shape)
return inputs * noisy


def autocast_softmax(logit: torch.Tensor, dim: int):
if logit.dtype != torch.float32:
logit = logit.float()
return F.softmax(logit, dim=dim)


def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
mep_size = MOE_CONTEXT.max_ep_size
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
elif d_ff % mep_size == 0:
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
else:
raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
import torch
import torch.nn.functional as F

from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device

from .experts import FFNExperts, TPExperts


class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()


class NormalNoiseGenerator:
"""Generates a random noisy mask for logits tensor.

All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
`E = the number of experts`.

Args:
num_experts (int): The number of experts.
"""

def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
).rsample

def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy


class UniformNoiseGenerator:
"""Generates a random noisy mask for logits tensor.
copied from mesh tensorflow:
Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.

Args:
eps (float, optional): Epsilon in generator, defaults 1e-2.
"""

def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps, device=get_current_device()),
).rsample

def __call__(self, inputs: torch.Tensor):
noisy = self.uniform(inputs.shape)
return inputs * noisy


def autocast_softmax(logit: torch.Tensor, dim: int):
if logit.dtype != torch.float32:
logit = logit.float()
return F.softmax(logit, dim=dim)


def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
mep_size = MOE_CONTEXT.max_ep_size
if num_experts % mep_size == 0 or mep_size % num_experts == 0:
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
elif d_ff % mep_size == 0:
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
else:
raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
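A short usage sketch for the noise generators and the expert builder defined in this file; the tensor sizes are illustrative:

    logits = torch.randn(16, 8, device=get_current_device())

    jitter = UniformNoiseGenerator(eps=1e-2)        # scales logits by U(1 - eps, 1 + eps)
    gaussian = NormalNoiseGenerator(num_experts=8)  # adds N(0, 1 / E^2) noise
    noisy_logits = gaussian(jitter(logits))

    # build_ffn_experts returns FFNExperts when num_experts and MOE_CONTEXT.max_ep_size
    # are divisible one by the other, and falls back to TPExperts when only d_ff is divisible.
    experts = build_ffn_experts(num_experts=8, d_model=768, d_ff=3072)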
@@ -8,7 +8,6 @@ def divide(numerator, denominator):
Returns:
int: the result of exact division.
"""
assert denominator != 0, 'denominator can not be zero'
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
assert denominator != 0, "denominator can not be zero"
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
return numerator // denominator
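For illustration, divide() simply guards integer division with the two assertions above:

    divide(12, 4)  # -> 3
    divide(12, 5)  # AssertionError: 12 is not divisible by 5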