Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-07 12:29:09 +00:00
[NFC] polish colossalai/kernel/cuda_native/csrc/moe_cuda.cpp code style (#642)
parent 10afec728f · commit 5ab9a71299
colossalai/kernel/cuda_native/csrc/moe_cuda.cpp
@@ -1,56 +1,47 @@
 #include <torch/extension.h>
 
-torch::Tensor moe_dispatch_cuda_forward(
-    int s, int ec, int h,
-    torch::Tensor batch_tokens,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
+                                        torch::Tensor batch_tokens,
+                                        torch::Tensor mask,
+                                        torch::Tensor dest_idx);
 
-torch::Tensor moe_dispatch_cuda_backward(
-    int s, int ec, int h,
-    torch::Tensor expert_grad,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
+                                         torch::Tensor expert_grad,
+                                         torch::Tensor mask,
+                                         torch::Tensor dest_idx);
 
-torch::Tensor moe_combine_cuda_forward(
-    int s, int e, int c, int h,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
+                                       torch::Tensor expert_tokens,
+                                       torch::Tensor logits, torch::Tensor mask,
+                                       torch::Tensor dest_idx);
 
-std::vector<torch::Tensor> moe_combine_cuda_backward(
-    int s, int e, int c, int h,
-    torch::Tensor tokens_grad,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+std::vector<torch::Tensor>
+moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                          torch::Tensor expert_tokens, torch::Tensor logits,
+                          torch::Tensor mask, torch::Tensor dest_idx);
 
 torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
 
-#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
 
-torch::Tensor moe_dispatch_forward(
-    int s, int ec, int h,
-    torch::Tensor batch_tokens,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+torch::Tensor moe_dispatch_forward(int s, int ec, int h,
+                                   torch::Tensor batch_tokens,
+                                   torch::Tensor mask, torch::Tensor dest_idx) {
 
   CHECK_INPUT(batch_tokens);
   CHECK_CUDA(mask);
   CHECK_CUDA(dest_idx);
 
-  return moe_dispatch_cuda_forward(
-      s, ec, h,
-      batch_tokens, mask, dest_idx);
+  return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
 }
 
-torch::Tensor moe_dispatch_backward(
-    int s, int ec, int h,
-    torch::Tensor expert_grad,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+torch::Tensor moe_dispatch_backward(int s, int ec, int h,
+                                    torch::Tensor expert_grad,
+                                    torch::Tensor mask,
+                                    torch::Tensor dest_idx) {
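One style note on the macros touched in this hunk: CHECK_INPUT expands to two separate statements, so a braceless `if (cond) CHECK_INPUT(t);` would guard only the first check. A minimal sketch of the conventional do/while(0) fix; the name CHECK_INPUT_SAFE is hypothetical and not part of this commit:

// Hypothetical variant (not in this commit): wrapping both checks in
// do { ... } while (0) makes the macro expand as a single statement, so
// `if (cond) CHECK_INPUT_SAFE(t);` guards both checks instead of one.
#define CHECK_INPUT_SAFE(x) \
  do {                      \
    CHECK_CUDA(x);          \
    CHECK_CONTIGUOUS(x);    \
  } while (0)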
@@ -59,16 +50,12 @@ torch::Tensor moe_dispatch_backward(
   CHECK_CUDA(mask);
   CHECK_CUDA(dest_idx);
 
-  return moe_dispatch_cuda_backward(
-      s, ec, h,
-      expert_grad, mask, dest_idx);
+  return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
 }
 
-torch::Tensor moe_combine_forward(
-    int s, int e, int c, int h,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+torch::Tensor moe_combine_forward(int s, int e, int c, int h,
+                                  torch::Tensor expert_tokens,
+                                  torch::Tensor logits, torch::Tensor mask,
+                                  torch::Tensor dest_idx) {
 
   CHECK_INPUT(expert_tokens);
@@ -76,27 +63,22 @@ torch::Tensor moe_combine_forward(
   CHECK_CUDA(mask);
   CHECK_CUDA(dest_idx);
 
-  return moe_combine_cuda_forward(
-      s, e, c, h,
-      expert_tokens, logits, mask, dest_idx);
+  return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
+                                  dest_idx);
 }
 
-std::vector<torch::Tensor> moe_combine_backward(
-    int s, int e, int c, int h,
-    torch::Tensor tokens_grad,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+std::vector<torch::Tensor>
+moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                     torch::Tensor expert_tokens, torch::Tensor logits,
+                     torch::Tensor mask, torch::Tensor dest_idx) {
 
   CHECK_INPUT(tokens_grad);
   CHECK_INPUT(logits);
   CHECK_CUDA(mask);
   CHECK_CUDA(dest_idx);
 
-  return moe_combine_cuda_backward(
-      s, e, c, h,
-      tokens_grad, expert_tokens, logits, mask, dest_idx);
+  return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
+                                   logits, mask, dest_idx);
 }
 
 torch::Tensor moe_cumsum(torch::Tensor mask) {
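For readers following the CHECK calls above, a minimal caller-side sketch against libtorch. The tensor shapes, dtypes, and the meanings of s/ec/h are illustrative assumptions; the commit itself does not document them:

#include <torch/torch.h>

// Declaration of the wrapper defined in moe_cuda.cpp above; this sketch
// assumes it is compiled and linked together with that file and its kernels.
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
                                   torch::Tensor batch_tokens,
                                   torch::Tensor mask, torch::Tensor dest_idx);

void example_dispatch() {
  const int s = 8, ec = 4, h = 16;  // assumed: tokens, expert capacity, hidden
  auto cuda = torch::TensorOptions().device(torch::kCUDA);
  // batch_tokens passes through CHECK_INPUT: must be CUDA *and* contiguous.
  torch::Tensor batch_tokens = torch::randn({s, h}, cuda);
  // mask and dest_idx pass only through CHECK_CUDA: contiguity not enforced.
  torch::Tensor mask = torch::ones({s}, cuda.dtype(torch::kInt32));
  torch::Tensor dest_idx = torch::arange(s, cuda.dtype(torch::kInt32));
  torch::Tensor out = moe_dispatch_forward(s, ec, h, batch_tokens, mask, dest_idx);
}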
@@ -105,8 +87,7 @@ torch::Tensor moe_cumsum(torch::Tensor mask) {
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("cumsum_sub_one", &moe_cumsum,
-        "Fast cumsum operation in dim0");
+  m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");
   m.def("dispatch_forward", &moe_dispatch_forward,
         "Forward operation in MoE dispatch function");
   m.def("dispatch_backward", &moe_dispatch_backward,
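As a reference for the binding block in this last hunk: TORCH_EXTENSION_NAME is a macro that PyTorch's extension build defines to the chosen module name, so the same source builds under any name. A self-contained minimal sketch of the same pattern; the identity function is illustrative only, not from the commit:

#include <torch/extension.h>

// Illustrative binding, not from the commit.
torch::Tensor identity(torch::Tensor x) { return x; }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("identity", &identity, "Returns its input unchanged");
}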