Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-17 07:00:37 +00:00)
[MOE] support PR-MOE (#488)
@@ -4,11 +4,12 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
+from typing import List


 class VanillaSelfAttention(nn.Module):
@@ -146,7 +147,8 @@ class Widenet(nn.Module):
 class ViTMoE(nn.Module):

     def __init__(self,
-                 num_experts: int,
+                 num_experts: int or List[int],
+                 use_residual: bool = False,
                  capacity_factor_train: float = 1.25,
                  capacity_factor_eval: float = 2.0,
                  drop_tks: bool = True,
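A side note on the annotation in this hunk: `int or List[int]` is evaluated eagerly by Python, so the resulting annotation is just `int` (`or` returns the first truthy operand, and classes are truthy). A minimal check, plus the `typing` form that expresses the intent:

    from typing import List, Union

    print(int or List[int])               # <class 'int'>
    num_experts: Union[int, List[int]]    # what the signature intends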
@@ -164,29 +166,45 @@ class ViTMoE(nn.Module):
                  drop_path: float = 0.):
         super().__init__()

+        assert depth % 2 == 0, "The number of layers should be even right now"
+
+        if isinstance(num_experts, list):
+            assert len(num_experts) == depth // 2, \
+                "The length of num_experts should equal to the number of MOE layers"
+            num_experts_list = num_experts
+        else:
+            num_experts_list = [num_experts] * (depth // 2)
+
         embedding = VanillaPatchEmbedding(img_size=img_size,
                                           patch_size=patch_size,
                                           in_chans=in_chans,
                                           embed_size=d_model)
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
-        noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor_train=capacity_factor_train,
-                            capacity_factor_eval=capacity_factor_eval,
-                            noisy_func=noisy_func,
-                            drop_tks=drop_tks)
-        assert depth % 2 == 0
+
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
         blocks = []
         for i in range(depth):
             sa = VanillaSelfAttention(**moe_sa_args(
                 d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
-            ffn = VanillaFFN(**moe_mlp_args(
-                d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
-                MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+
+            if i % 2 == 0:
+                ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+            else:
+                num_experts = num_experts_list[i // 2]
+                experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+                ffn = MoeModule(dim_model=d_model,
+                                num_experts=num_experts,
+                                top_k=1 if use_residual else 2,
+                                capacity_factor_train=capacity_factor_train,
+                                capacity_factor_eval=capacity_factor_eval,
+                                noisy_policy='Jitter' if use_residual else 'Gaussian',
+                                drop_tks=drop_tks,
+                                use_residual=use_residual,
+                                expert_instance=experts,
+                                expert_cls=VanillaFFN,
+                                **moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),
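For reference, a minimal usage sketch of the changed constructor. Only num_experts, use_residual, the capacity factors, and drop_tks are confirmed by this diff; the full ViTMoE signature is truncated in the hunk, so passing depth as a keyword (it is clearly used inside __init__) and the specific values below are assumptions. A list enables a pyramid layout (one entry per MoE layer), and use_residual=True switches to PR-MoE-style top-1 routing with an always-on dense expert.

    # Hypothetical usage sketch -- parameters beyond those visible in the hunk
    # are assumptions based on how they are used in __init__.
    model = ViTMoE(
        num_experts=[4, 4, 4, 8, 8, 16],  # pyramid: len == depth // 2, one entry per MoE layer
        use_residual=True,                # PR-MoE: top-1 routing plus a dense residual FFN
        capacity_factor_train=1.25,
        capacity_factor_eval=2.0,
        drop_tks=True,
        depth=12,                         # must be even; MoE layers sit at odd block indices
    )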
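The use_residual branch configures MoeModule with top_k=1, 'Jitter' noise, and expert_cls=VanillaFFN for the residual path. The sketch below illustrates the Residual-MoE combination from the PR-MoE design (DeepSpeed-MoE); it is not ColossalAI's MoeModule implementation, just the idea the flag enables, and the learned per-token mixing coefficient is an assumption carried over from that formulation.

    import torch
    import torch.nn as nn

    class ResidualMoESketch(nn.Module):
        # Illustrative only -- NOT ColossalAI's MoeModule. Shows how PR-MoE
        # combines an always-on dense FFN with a top-1 routed expert layer.
        def __init__(self, moe_layer: nn.Module, dense_ffn: nn.Module, d_model: int):
            super().__init__()
            self.moe = moe_layer               # top-1 routed experts (top_k=1 above)
            self.dense = dense_ffn             # residual expert (expert_cls=VanillaFFN above)
            self.coef = nn.Linear(d_model, 2)  # learned per-token mixing weights (assumed)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            w = torch.softmax(self.coef(x), dim=-1)  # (..., 2)
            return self.dense(x) * w[..., 0:1] + self.moe(x) * w[..., 1:2]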