[setup] support pre-build and jit-build of cuda kernels (#2374)

* [setup] support pre-build and jit-build of cuda kernels

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-01-06 20:50:26 +08:00
committed by GitHub
parent 12c8bf38d7
commit 40d376c566
36 changed files with 414 additions and 390 deletions

View File

@@ -6,13 +6,32 @@ from torch import Tensor
from torch.distributed import ProcessGroup
COL_MOE_KERNEL_FLAG = False
from colossalai.kernel import moe
try:
from colossalai._C import moe
except:
moe = None
def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
if ctx is not None:
ctx.comm_grp = group
@@ -85,6 +104,9 @@ class MoeDispatch(torch.autograd.Function):
s = tokens.size(0)
h = tokens.size(1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
@@ -112,6 +134,9 @@ class MoeCombine(torch.autograd.Function):
c = ec // e
h = expert_tokens.size(-1)
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
@@ -143,6 +168,8 @@ def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1