Mirror of https://github.com/hpcaitech/ColossalAI.git
[builder] MOE builder (#2277)
@@ -6,12 +6,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 
 COL_MOE_KERNEL_FLAG = False
-try:
-    import colossalai._C.moe
-
-    COL_MOE_KERNEL_FLAG = True
-except ImportError:
-    print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
+from colossalai.kernel import moe
 
 
 class AllGather(torch.autograd.Function):
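
Note on the first hunk: the old code imported the pre-compiled colossalai._C.moe extension directly and only flipped COL_MOE_KERNEL_FLAG when that import succeeded, printing a hint about cuda_ext otherwise. The new code pulls moe from colossalai.kernel instead, which, per the commit title, is where the new MOE builder is expected to provide the kernel. A minimal sketch of what a builder-backed kernel module could look like; the MOEBuilder name and the load() call are illustrative assumptions, not verified contents of this PR:

    # Sketch only: how colossalai.kernel might expose `moe` through an op builder.
    try:
        from colossalai._C import moe                         # prefer the ahead-of-time compiled extension
    except ImportError:
        from colossalai.kernel.op_builder import MOEBuilder   # assumed builder class added by this PR
        moe = MOEBuilder().load()                              # JIT-build and load the MoE CUDA kernel
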
@@ -90,7 +85,7 @@ class MoeDispatch(torch.autograd.Function):
         s = tokens.size(0)
         h = tokens.size(1)
 
-        expert_input = colossalai._C.moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
+        expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
 
         ctx.save_for_backward(mask, dest_idx)
         ctx.s = s
@@ -102,7 +97,7 @@ class MoeDispatch(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossalai._C.moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None
 
 
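
Both MoeDispatch hunks are the same mechanical substitution: dispatch_forward and dispatch_backward are now looked up on the moe module object rather than on the colossalai._C.moe attribute path. Structurally the class is a standard torch.autograd.Function wrapping a fused forward/backward kernel pair in which only the token tensor is differentiable, hence the trailing None gradients. A toy, runnable stand-in for that structure; index_select / index_add_ here only illustrate the gather/scatter shape of dispatch, not the real kernel semantics:

    import torch

    class ToyDispatch(torch.autograd.Function):
        """Illustrative only: gather tokens by a routing index, like a much simplified dispatch."""

        @staticmethod
        def forward(ctx, tokens, dest_idx):
            ctx.save_for_backward(dest_idx)      # stash the non-differentiable routing data
            ctx.s = tokens.size(0)
            return tokens.index_select(0, dest_idx)

        @staticmethod
        def backward(ctx, grad_out):
            dest_idx, = ctx.saved_tensors
            d_tokens = torch.zeros(ctx.s, grad_out.size(1),
                                   dtype=grad_out.dtype, device=grad_out.device)
            d_tokens.index_add_(0, dest_idx, grad_out)   # scatter-add grads back to token rows
            return d_tokens, None                        # None: the routing index gets no gradient
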
@@ -119,7 +114,7 @@ class MoeCombine(torch.autograd.Function):
 
         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossalai._C.moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
+        ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens
 
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -138,8 +133,7 @@
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossalai._C.moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
-                                                                mask, dest_idx)
+        d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
 
         return d_expert, d_logits, None, None, None
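
Outside of the renamed kernel calls, both MoeCombine passes keep the same dtype discipline: fp16 tensors are upcast to fp32 before the fused kernel runs, and the result is cast back to fp16 afterwards, so the kernel only ever sees fp32 data. A small sketch of that round-trip pattern; the helper name is illustrative and not part of this PR:

    import torch

    def run_in_fp32(kernel_fn, x: torch.Tensor, *args) -> torch.Tensor:
        # Upcast an fp16 input, call the (fp32-only) kernel, then restore the original dtype.
        fp16_flag = (x.dtype == torch.float16)
        out = kernel_fn(x.to(torch.float32) if fp16_flag else x, *args)
        return out.to(torch.float16) if fp16_flag else out
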
@@ -149,6 +143,6 @@ def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
     if flag and COL_MOE_KERNEL_FLAG:
-        return colossalai._C.moe.cumsum_sub_one(inputs)
+        return moe.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
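
For reference, moe_cumsum is a cumulative sum along dim 0 shifted down by one; the fused kernel path (cumsum_sub_one) is only taken when the leading dimension passes the size/alignment check and the kernel flag is set, otherwise the plain PyTorch expression from the last context line is used:

    import torch

    def moe_cumsum_fallback(inputs: torch.Tensor) -> torch.Tensor:
        # Same result as the CUDA cumsum_sub_one on the fallback path.
        return torch.cumsum(inputs, dim=0) - 1
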