diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py
index 053de0ef6..662a907fc 100644
--- a/colossalai/nn/layer/moe/_operation.py
+++ b/colossalai/nn/layer/moe/_operation.py
@@ -21,17 +21,15 @@ class AllToAll(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx: Any,
-                inputs: Tensor,
-                parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
         if ctx is not None:
             ctx.parallel_mode = parallel_mode
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
-
+        if gpc.get_world_size(parallel_mode) == 1:
+            return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs,
-                               group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
         return output

     @staticmethod
@@ -58,8 +56,7 @@ class MoeDispatch(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossal_moe_cuda.dispatch_backward(
-            ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None


@@ -76,9 +73,7 @@ class MoeCombine(torch.autograd.Function):

         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
-                                                    cb_input, logits,
-                                                    mask, dest_idx)
+        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens

         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -97,9 +92,8 @@ class MoeCombine(torch.autograd.Function):
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossal_moe_cuda.combine_backward(
-            ctx.s, ctx.e, ctx.c, ctx.h,
-            cb_grad, cb_input, logits, mask, dest_idx)
+        d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
+                                                                mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

         return d_expert, d_logits, None, None, None
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index 4688c9ce7..ef9618c1a 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -62,10 +62,12 @@ class FFNExperts(nn.Module):

         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)
-        nn.init.trunc_normal_(self.w1, std=s1)
-        nn.init.trunc_normal_(self.b1, std=s1)
-        nn.init.trunc_normal_(self.w2, std=s2)
-        nn.init.trunc_normal_(self.b2, std=s2)
+
+        with seed(ParallelMode.MOE_MODEL):
+            nn.init.trunc_normal_(self.w1, std=s1)
+            nn.init.trunc_normal_(self.b1, std=s1)
+            nn.init.trunc_normal_(self.w2, std=s2)
+            nn.init.trunc_normal_(self.b2, std=s2)

         self.act = nn.GELU() if activation is None else activation
         self.drop = nn.Dropout(p=drop_rate)