Mirror of https://github.com/hpcaitech/ColossalAI.git
[MOE] add unit tests for MOE experts layout, gradient handler and kernel (#469)
@@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
 from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 
 
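The substantive change in this hunk is relocating the global MoE state: moe_env from colossalai.global_variables is replaced by MOE_CONTEXT from colossalai.core. For downstream code that must run on both sides of the rename, one hypothetical way to cope is an import shim; this is a sketch, not part of this commit, and it relies only on the two import paths and the reset_loss() call that appear in the diff:

    try:
        from colossalai.core import MOE_CONTEXT          # new location, as of this commit
    except ImportError:
        # Pre-#469 fallback: the old global object under its old import path.
        from colossalai.global_variables import moe_env as MOE_CONTEXT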
@@ -136,7 +136,7 @@ class Widenet(nn.Module):
         self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
 
     def forward(self, x):
-        moe_env.reset_loss()
+        MOE_CONTEXT.reset_loss()
         x = self.widenet(x)
         x = torch.mean(x, dim=1)
         x = self.linear(x)
@@ -201,7 +201,7 @@ class ViTMoE(nn.Module):
         self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)
 
     def forward(self, x):
-        moe_env.reset_loss()
+        MOE_CONTEXT.reset_loss()
        x = self.vitmoe(x)
         x = torch.mean(x, dim=1)
         x = self.linear(x)
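Both forward() methods follow the same pattern: a global MoE context accumulates an auxiliary loss (the routers' load-balancing loss in MoE setups) during the forward pass, so it must be cleared before each new batch. Below is a minimal, self-contained sketch of that pattern; the _MoeContext class, its aux_loss attribute, and add_loss() are stand-ins invented for illustration, and only reset_loss() is taken from the diff above:

    class _MoeContext:
        """Illustrative stand-in for a global MoE loss accumulator."""

        def __init__(self):
            self.aux_loss = 0.0

        def reset_loss(self):
            # Called at the top of forward(), as in the diff, so that
            # auxiliary loss from the previous batch does not leak in.
            self.aux_loss = 0.0

        def add_loss(self, loss):
            # A router (e.g. Top2Router) would accumulate its balance
            # loss here while dispatching tokens during the forward pass.
            self.aux_loss += loss

    MOE_CONTEXT = _MoeContext()

    # Usage mirroring Widenet/ViTMoE: reset, run the model, then combine
    # the task loss with whatever auxiliary loss accumulated on the way.
    MOE_CONTEXT.reset_loss()
    MOE_CONTEXT.add_loss(0.01)    # a router would do this internally
    total_loss = 1.25 + MOE_CONTEXT.aux_loss
    print(total_loss)             # prints 1.26

In the real models, the combined loss is what gets backpropagated, which is why stale auxiliary loss from an earlier batch would silently skew gradients if reset_loss() were omitted.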