[Inference/Refactor] Refactor compilation mechanism and unify multi-hardware support (#5613)

* refactor compilation mechanism and unify multi-hardware support

* fix file path bug

* add __init__.py to make pybind a package, avoiding the relative-path error caused by the symlink

* delete duplicated macros

* fix macro bug in gcc
傅剑寒
2024-04-24 14:17:54 +08:00
committed by GitHub
parent 04863a9b14
commit 279300dc5f
64 changed files with 345 additions and 310 deletions


@@ -0,0 +1,3 @@
from .moe_cuda import MoeCudaExtension

__all__ = ["MoeCudaExtension"]


@@ -0,0 +1,97 @@
#include <torch/extension.h>

#include <vector>

// Forward declarations of the CUDA kernels implemented in moe_kernel.cu.
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                        torch::Tensor batch_tokens,
                                        torch::Tensor mask,
                                        torch::Tensor dest_idx);

torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                         torch::Tensor expert_grad,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx);

torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx);

std::vector<torch::Tensor> moe_combine_cuda_backward(
    int s, int e, int c, int h, torch::Tensor tokens_grad,
    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
    torch::Tensor dest_idx);

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);

// Input validation helpers. CHECK_INPUT requires both a CUDA device and a
// contiguous layout; CHECK_CUDA alone only requires the tensor to be on GPU.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
// Host-side wrappers: validate the inputs, then call into the CUDA kernels.
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
                                   torch::Tensor batch_tokens,
                                   torch::Tensor mask, torch::Tensor dest_idx) {
  CHECK_INPUT(batch_tokens);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

  return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
}

torch::Tensor moe_dispatch_backward(int s, int ec, int h,
                                    torch::Tensor expert_grad,
                                    torch::Tensor mask,
                                    torch::Tensor dest_idx) {
  CHECK_INPUT(expert_grad);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

  return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
}

torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                  torch::Tensor expert_tokens,
                                  torch::Tensor logits, torch::Tensor mask,
                                  torch::Tensor dest_idx) {
  CHECK_INPUT(expert_tokens);
  CHECK_INPUT(logits);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

  return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
                                  dest_idx);
}

std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
                                                torch::Tensor tokens_grad,
                                                torch::Tensor expert_tokens,
                                                torch::Tensor logits,
                                                torch::Tensor mask,
                                                torch::Tensor dest_idx) {
  CHECK_INPUT(tokens_grad);
  CHECK_INPUT(logits);
  CHECK_CUDA(mask);
  CHECK_CUDA(dest_idx);

  return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
                                   logits, mask, dest_idx);
}

torch::Tensor moe_cumsum(torch::Tensor mask) {
  CHECK_INPUT(mask);
  return cumsum_sub_one_in_dim0(mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");

  m.def("dispatch_forward", &moe_dispatch_forward,
        "Forward operation in MoE dispatch function");
  m.def("dispatch_backward", &moe_dispatch_backward,
        "Backward operation in MoE dispatch function");

  m.def("combine_forward", &moe_combine_forward,
        "Forward operation in MoE combine function");
  m.def("combine_backward", &moe_combine_backward,
        "Backward operation in MoE combine function");
}
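
Once built, this module behaves like any other torch C++ extension. The sketch below is illustrative only: the module name moe_cuda comes from the extension name in this commit, but the dtypes, shapes, and the exact semantics of mask and dest_idx are assumptions, not something this diff specifies.

import torch
import moe_cuda  # module name assumed from name="moe_cuda" / TORCH_EXTENSION_NAME

# Illustrative sizes: s tokens, e experts, capacity c, hidden size h.
s, e, c, h = 8, 4, 2, 16
tokens = torch.randn(s, h, device="cuda")  # CHECK_INPUT: CUDA and contiguous
mask = torch.randint(0, 2, (s, e), device="cuda", dtype=torch.int32)      # dtype assumed
dest_idx = torch.randint(0, e * c, (s,), device="cuda", dtype=torch.int32)  # dtype assumed

pos = moe_cuda.cumsum_sub_one(mask)  # cumsum along dim0, minus one
expert_tokens = moe_cuda.dispatch_forward(s, e * c, h, tokens, mask, dest_idx)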


@@ -0,0 +1,27 @@
from ...cuda_extension import _CudaExtension
from ...utils import append_nvcc_threads, get_cuda_cc_flag


class MoeCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="moe_cuda")

    def sources_files(self):
        # The CUDA kernel plus the pybind11 wrapper that exposes it to Python.
        ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/moe_kernel.cu"]] + [
            self.pybind_abs_path("moe/moe.cpp")
        ]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
        ]
        # Add arch-specific compute-capability flags for the build environment.
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
        return append_nvcc_threads(ret)
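
For context, the class above only declares sources and flags; compilation itself is driven by the _CudaExtension base class, which is not part of this hunk. A minimal sketch of how the extension might be exercised, where both the import path and the load() entry point are assumptions about the surrounding framework rather than anything this diff shows:

from extensions.pybind.moe import MoeCudaExtension  # import path assumed from the relative imports

ext = MoeCudaExtension()
print(ext.sources_files())  # [<csrc>/kernel/cuda/moe_kernel.cu, <pybind>/moe/moe.cpp]
print(ext.nvcc_flags())     # ["-O3", "--use_fast_math", ...arch flags, ...base flags]

moe_cuda = ext.load()  # hypothetical JIT-build hook on _CudaExtension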