fix format (#583)

binmakeswell 2022-03-31 18:10:03 +08:00
parent f08fc17f2b
commit 6582aedc94


@@ -1,66 +1,70 @@
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row + idx, pack);
    BlockStore(ts_store).Store(dst_row + idx, pack);
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(dst_row + idx, pack);
    BlockStore(ts_store).Store(src_row + idx, pack);
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row + idx, pack);
    BlockStore(ts_store).Store(dst_row1 + idx, pack);

@@ -68,18 +72,19 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;

@@ -88,7 +93,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
    BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
    BlockLoad(ts_load).Load(dst_row2 + idx, pack2);

#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      pack1[i] += pack2[i];
    }

@@ -97,27 +102,27 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row + idx, pack);

#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      pack[i] *= weight;
    }

@@ -126,20 +131,19 @@ __device__ void moe_cb_one_fwd(
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
                               T *weight_grad, const T weight, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;

@@ -149,7 +153,7 @@ __device__ void moe_cb_one_bwd(
    BlockLoad(ts_load).Load(dst_row + idx, grad);
    BlockLoad(ts_load).Load(tks_row + idx, tokens);

#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      thread_sum += grad[i] * tokens[i];
      grad[i] *= weight;

@@ -164,20 +168,20 @@ __device__ void moe_cb_one_bwd(
    *weight_grad = static_cast<T>(thread_sum);
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
                               const T weight1, const T weight2,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;

@@ -186,7 +190,7 @@ __device__ void moe_cb_two_fwd(
    BlockLoad(ts_load).Load(src_row1 + idx, pack1);
    BlockLoad(ts_load).Load(src_row2 + idx, pack2);

#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
    }

@@ -195,33 +199,33 @@ __device__ void moe_cb_two_fwd(
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
                               T *tks_row1, T *tks_row2, T *weight_grad1,
                               T *weight_grad2, const T weight1,
                               const T weight2, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
      sgrad2[pack_size];
  float thread_sum[2] = {0, 0};

  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(dst_row + idx, grad);
    BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
    BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);

#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      thread_sum[0] += grad[i] * tokens1[i];
      thread_sum[1] += grad[i] * tokens2[i];

@@ -239,142 +243,125 @@ __device__ void moe_cb_two_bwd(
    *weight_grad1 = static_cast<T>(thread_sum[0]);
  else if (threadIdx.x == 1)
    *weight_grad2 = static_cast<T>(thread_sum[1]);
}

// DISPATCH KERNELS --------------------------------
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
  else if (indicator1 != 0)
    moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
  else if (indicator2 != 0)
    moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
  else if (indicator1 != 0)
    moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
  else if (indicator2 != 0)
    moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
                                    int *mask1, int *mask2, int *dest1,
                                    int *dest2, const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];

  moe_dpch_fwd_selector<T, block_size, pack_size>(
      batch_tokens + (row * h), expert_input + (dest1[row] * h),
      expert_input + (dest2[row] * h), h, mask1[row], indicator2);
}

template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
                                    int *mask2, int *dest1, int *dest2,
                                    const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];

  moe_dpch_bwd_selector<T, block_size, pack_size>(
      tokens_grad + (row * h), expert_grad + (dest1[row] * h),
      expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
}

// COMBINE KERNELS --------------------------------
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    const int cols, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             weight1, weight2, cols);
  else if (indicator1 != 0)
    moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
  else if (indicator2 != 0)
    moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    const int cols, T *tks_row1, T *tks_row2,
                                    T *wt_grad1, T *wt_grad2, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             tks_row1, tks_row2, wt_grad1,
                                             wt_grad2, weight1, weight2, cols);
  else if (indicator1 != 0)
    moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
                                             wt_grad1, weight1, cols);
  else if (indicator2 != 0)
    moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
                                             wt_grad2, weight2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
                                  T *logits, int *mask1, int *mask2, int *dest1,
                                  int *dest2, const int e, const int c,
                                  const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e);

  moe_cb_fwd_selector<T, block_size, pack_size>(
      expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
      combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
      indicator2);
}

template <typename T, int block_size, int pack_size>
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
                                  T *logits, T *logits_grad, int *mask1,
                                  int *mask2, int *dest1, int *dest2,
                                  const int e, const int c, const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;

@@ -382,36 +369,33 @@ __global__ void moe_cb_bwd_kernel(
  T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);

  moe_cb_bwd_selector<T, block_size, pack_size>(
      expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
      tokens_grad + (row * h), h, tks + (dest1[row] * h),
      tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
      row_log[eid2], mask1[row], indicator2);
}

// CUMSUM KERNEL --------------------------------
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
                              const int e) {
  assert(s % pack_size == 0);
  constexpr int bpack_size = block_size * pack_size;

  int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
  __shared__ int temp[block_size + 1];
  int pack[pack_size];

  for (int idx = 0; idx < s; idx += bpack_size) {
    int offset = 1;

    if (idx + tps < s) {
      temp[tid] = inputs[tps * e + bid];
#pragma unroll
      for (int i = 1; i < pack_size; ++i) {
        pack[i] = inputs[(tps + i) * e + bid];
      }
#pragma unroll
      for (int i = 1; i < pack_size; ++i) {
        temp[tid] += pack[i];
      }

@@ -435,8 +419,7 @@ __global__ void cumsum_kernel(
      offset >>= 1;
      __syncthreads();
      if (tid < i) {
        int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
        temp[j] = temp[k];
        temp[k] += ts;
      }

@@ -449,7 +432,7 @@ __global__ void cumsum_kernel(
    if (idx + tps < s) {
      temp[tid + 1] += last_sum;

#pragma unroll
      for (int i = pack_size - 1; i > 0; --i) {
        outputs[(tps + i) * e + bid] = temp[tid + 1];
        temp[tid + 1] -= pack[i];

@@ -464,95 +447,103 @@ __global__ void cumsum_kernel(
  }
}

// LAUNCH FUNCTIONS --------------------------------
template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
                         int *mask2, int *dest1, int *dest2, const int s,
                         const int h) {
  if (h < 256)
    moe_dpch_fwd_kernel<T, 32, 4>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 512)
    moe_dpch_fwd_kernel<T, 32, 8>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 1024)
    moe_dpch_fwd_kernel<T, 32, 16>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 2048)
    moe_dpch_fwd_kernel<T, 64, 16>
        <<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else
    moe_dpch_fwd_kernel<T, 128, 16>
        <<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
}

template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
                         int *dest1, int *dest2, const int s, const int h) {
  if (h < 256)
    moe_dpch_bwd_kernel<T, 32, 4>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else if (h < 512)
    moe_dpch_bwd_kernel<T, 32, 8>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else if (h < 1024)
    moe_dpch_bwd_kernel<T, 32, 16>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else if (h < 2048)
    moe_dpch_bwd_kernel<T, 64, 16>
        <<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else
    moe_dpch_bwd_kernel<T, 128, 16>
        <<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
}

template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                       int *mask1, int *mask2, int *dest1, int *dest2,
                       const int s, const int e, const int c, const int h) {
  if (h < 256)
    moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,
                                           e, c, h);
  else if (h < 512)
    moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,
                                           e, c, h);
  else if (h < 1024)
    moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
                                            logits, mask1, mask2, dest1, dest2,
                                            e, c, h);
  else if (h < 2048)
    moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
                                            logits, mask1, mask2, dest1, dest2,
                                            e, c, h);
  else
    moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
                                              logits, mask1, mask2, dest1,
                                              dest2, e, c, h);
}

template <typename T>
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                       T *logits_grad, int *mask1, int *mask2, int *dest1,
                       int *dest2, const int s, const int e, const int c,
                       const int h) {
  if (h < 256)
    moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  else // if (h < 512)
    moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  // else if (h < 1024)
  //   moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
  //       (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
  //       dest1, dest2, e, c, h);
  // else
  //   moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
  //       (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
  //       dest1, dest2, e, c, h);
}

void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
  if (s <= 256)
    cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);

@@ -569,32 +560,29 @@ void cumsum_launch(
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                             \
  switch (TYPE) {                                                            \
  case at::ScalarType::Float: {                                              \
    using scalar_t = float;                                                  \
    __VA_ARGS__;                                                             \
    break;                                                                   \
  }                                                                          \
  case at::ScalarType::Half: {                                               \
    using scalar_t = at::Half;                                               \
    __VA_ARGS__;                                                             \
    break;                                                                   \
  }                                                                          \
  default:                                                                   \
    AT_ERROR(#NAME, " not implemented yet for specific data type.");         \
  }

torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                        torch::Tensor batch_tokens,
                                        torch::Tensor mask,
                                        torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {ec, h},
      torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
  auto k = mask.size(0);

@@ -603,22 +591,20 @@ torch::Tensor moe_dispatch_cuda_forward(
      moe_dpch_fwd_launch<scalar_t>(
          batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
          mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
          dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));

  return res;
}

torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                         torch::Tensor expert_grad,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(

@@ -626,65 +612,62 @@ torch::Tensor moe_dispatch_cuda_backward(
      moe_dpch_bwd_launch<scalar_t>(
          res.data<scalar_t>(), expert_grad.data<scalar_t>(),
          mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
          dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));

  return res;
}

torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(expert_tokens.dtype() == logits.dtype());

  auto res = torch::zeros(
      {s, h},
      torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(
      expert_tokens.scalar_type(), "moe combine forward",
      moe_cb_fwd_launch<scalar_t>(
          expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
          logits.data<scalar_t>(), mask[0].data<int>(),
          k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
          h));

  return res;
}

std::vector<torch::Tensor>
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
                          torch::Tensor expert_tokens, torch::Tensor logits,
                          torch::Tensor mask, torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(tokens_grad.dtype() == expert_tokens.dtype());
  assert(expert_tokens.dtype() == logits.dtype());

  auto egrad = torch::zeros(
           {e * c, h},
           torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
       wgrad = torch::zeros(
           {s, e}, torch::dtype(logits.dtype()).device(logits.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(
      tokens_grad.scalar_type(), "moe combine backward",
      moe_cb_bwd_launch<scalar_t>(
          tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
          expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
          wgrad.data<scalar_t>(), mask[0].data<int>(),
          k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
          h));

  return {egrad, wgrad};
}

@@ -695,7 +678,8 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  assert(mask.dtype() == torch::kInt32);
  const int s = mask.size(0), e = mask.size(1);

  auto res =
      torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
  cumsum_launch(mask.data<int>(), res.data<int>(), s, e);

  return res;