added Multiply Jitter and capacity factor eval for MOE (#434)

This commit is contained in:
HELSON
2022-03-16 16:47:44 +08:00
committed by GitHub
parent b03b3ae99c
commit dbdc9a7783
3 changed files with 92 additions and 27 deletions

View File

@@ -84,7 +84,9 @@ class Widenet(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
@@ -109,7 +111,10 @@ class Widenet(nn.Module):
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule
@@ -142,7 +147,9 @@ class ViTMoE(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
@@ -164,8 +171,10 @@ class ViTMoE(nn.Module):
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor, noisy_func=noisy_func)
router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
assert depth % 2 == 0
# stochastic depth decay rule