mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-29 10:24:46 +00:00
fix format (#583)
This commit is contained in:
parent
f08fc17f2b
commit
6582aedc94
@ -1,8 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include "block_reduce.h"
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include "block_reduce.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||
@ -10,15 +10,16 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
@ -31,15 +32,16 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, pack);
|
||||
BlockStore(ts_store).Store(src_row + idx, pack);
|
||||
@ -47,20 +49,22 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
|
||||
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row1 + idx, pack);
|
||||
@ -69,17 +73,18 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
|
||||
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
@ -98,22 +103,22 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_fwd(
|
||||
T *src_row, T *dst_row,
|
||||
const T weight, const int cols) {
|
||||
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
|
||||
@ -127,19 +132,18 @@ __device__ void moe_cb_one_fwd(
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_bwd(
|
||||
T *src_row, T *dst_row, T *tks_row, T *weight_grad,
|
||||
const T weight, const int cols) {
|
||||
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
||||
T *weight_grad, const T weight, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
@ -165,19 +169,19 @@ __device__ void moe_cb_one_bwd(
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_fwd(
|
||||
T *src_row1, T *src_row2, T *dst_row,
|
||||
const T weight1, const T weight2, const int cols) {
|
||||
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
const T weight1, const T weight2,
|
||||
const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
@ -196,25 +200,25 @@ __device__ void moe_cb_two_fwd(
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_bwd(
|
||||
T *src_row1, T *src_row2, T *dst_row,
|
||||
T *tks_row1, T *tks_row2, T *weight_grad1, T *weight_grad2,
|
||||
const T weight1, const T weight2, const int cols) {
|
||||
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
T *tks_row1, T *tks_row2, T *weight_grad1,
|
||||
T *weight_grad2, const T weight1,
|
||||
const T weight2, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens1[pack_size], tokens2[pack_size],
|
||||
sgrad1[pack_size], sgrad2[pack_size];
|
||||
T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
|
||||
sgrad2[pack_size];
|
||||
float thread_sum[2] = {0, 0};
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
@ -239,142 +243,125 @@ __device__ void moe_cb_two_bwd(
|
||||
*weight_grad1 = static_cast<T>(thread_sum[0]);
|
||||
else if (threadIdx.x == 1)
|
||||
*weight_grad2 = static_cast<T>(thread_sum[1]);
|
||||
|
||||
}
|
||||
|
||||
// DISPATCH KERNELS --------------------------------
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_fwd_selector(
|
||||
T *src_row, T *dst_row1, T *dst_row2, const int cols,
|
||||
const int indicator1, const int indicator2) {
|
||||
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_fwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, dst_row2, cols);
|
||||
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, cols);
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(
|
||||
src_row, dst_row2, cols);
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_bwd_selector(
|
||||
T *src_row, T *dst_row1, T *dst_row2, const int cols,
|
||||
const int indicator1, const int indicator2) {
|
||||
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_bwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, dst_row2, cols);
|
||||
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, cols);
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(
|
||||
src_row, dst_row2, cols);
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_fwd_kernel(
|
||||
T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int h) {
|
||||
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int h) {
|
||||
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_fwd_selector<T, block_size, pack_size>(
|
||||
batch_tokens + (row * h),
|
||||
expert_input + (dest1[row] * h), expert_input + (dest2[row] * h),
|
||||
h, mask1[row], indicator2);
|
||||
batch_tokens + (row * h), expert_input + (dest1[row] * h),
|
||||
expert_input + (dest2[row] * h), h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_bwd_kernel(
|
||||
T *tokens_grad, T *expert_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int h) {
|
||||
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int h) {
|
||||
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_bwd_selector<T, block_size, pack_size>(
|
||||
tokens_grad + (row * h),
|
||||
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
|
||||
h, mask1[row], indicator2);
|
||||
tokens_grad + (row * h), expert_grad + (dest1[row] * h),
|
||||
expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// COMBINE KERNELS --------------------------------
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_fwd_selector(
|
||||
T *src_row1, T *src_row2, T *dst_row, const int cols,
|
||||
const T weight1, const T weight2,
|
||||
const int indicator1, const int indicator2) {
|
||||
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
||||
const int cols, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_fwd<T, block_size, pack_size>(
|
||||
src_row1, src_row2, dst_row, weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(
|
||||
src_row1, dst_row, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(
|
||||
src_row2, dst_row, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_bwd_selector(
|
||||
T *src_row1, T *src_row2, T *dst_row, const int cols,
|
||||
T *tks_row1, T *tks_row2, T *wt_grad1, T *wt_grad2,
|
||||
const T weight1, const T weight2,
|
||||
const int indicator1, const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_bwd<T, block_size, pack_size>(
|
||||
src_row1, src_row2, dst_row,
|
||||
tks_row1, tks_row2, wt_grad1, wt_grad2,
|
||||
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(
|
||||
src_row1, dst_row, tks_row1, wt_grad1, weight1, cols);
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(
|
||||
src_row2, dst_row, tks_row2, wt_grad2, weight2, cols);
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
||||
const int cols, T *tks_row1, T *tks_row2,
|
||||
T *wt_grad1, T *wt_grad2, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
tks_row1, tks_row2, wt_grad1,
|
||||
wt_grad2, weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
|
||||
wt_grad1, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
|
||||
wt_grad2, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_fwd_kernel(
|
||||
T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
|
||||
T *logits, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int e, const int c,
|
||||
const int h) {
|
||||
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e);
|
||||
moe_cb_fwd_selector<T, block_size, pack_size>(
|
||||
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
|
||||
combine_tokens + (row * h), h,
|
||||
row_log[eid1], row_log[eid2],
|
||||
mask1[row], indicator2);
|
||||
combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
|
||||
indicator2);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_bwd_kernel(
|
||||
T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
@ -382,25 +369,22 @@ __global__ void moe_cb_bwd_kernel(
|
||||
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
|
||||
moe_cb_bwd_selector<T, block_size, pack_size>(
|
||||
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
|
||||
tokens_grad + (row * h), h,
|
||||
tks + (dest1[row] * h), tks + (dest2[row] * h),
|
||||
row_grad + eid1, row_grad + eid2,
|
||||
row_log[eid1], row_log[eid2],
|
||||
mask1[row], indicator2);
|
||||
tokens_grad + (row * h), h, tks + (dest1[row] * h),
|
||||
tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
|
||||
row_log[eid2], mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// CUMSUM KERNEL --------------------------------
|
||||
|
||||
template <int block_size, int pack_size>
|
||||
__global__ void cumsum_kernel(
|
||||
int *inputs, int *outputs,
|
||||
const int s, const int e) {
|
||||
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
||||
const int e) {
|
||||
|
||||
assert(s % pack_size == 0);
|
||||
constexpr int bpack_size = block_size * pack_size;
|
||||
int tid = threadIdx.x, bid = blockIdx.x,
|
||||
tps = tid * pack_size, last_sum = -1;
|
||||
__shared__ int temp[block_size + 1]; int pack[pack_size];
|
||||
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
|
||||
__shared__ int temp[block_size + 1];
|
||||
int pack[pack_size];
|
||||
|
||||
for (int idx = 0; idx < s; idx += bpack_size) {
|
||||
int offset = 1;
|
||||
@ -435,8 +419,7 @@ __global__ void cumsum_kernel(
|
||||
offset >>= 1;
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1,
|
||||
k = j + offset, ts = temp[j];
|
||||
int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
|
||||
temp[j] = temp[k];
|
||||
temp[k] += ts;
|
||||
}
|
||||
@ -467,92 +450,100 @@ __global__ void cumsum_kernel(
|
||||
// LAUNCH FUNCTIONS --------------------------------
|
||||
|
||||
template <typename T>
|
||||
void moe_dpch_fwd_launch(
|
||||
T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int h) {
|
||||
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2, const int s,
|
||||
const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_dpch_fwd_kernel<T, 32, 4><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_fwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_fwd_kernel<T, 32, 8><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_fwd_kernel<T, 32, 8>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_fwd_kernel<T, 32, 16><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_fwd_kernel<T, 32, 16>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_fwd_kernel<T, 64, 16><<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_fwd_kernel<T, 64, 16>
|
||||
<<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_fwd_kernel<T, 128, 16><<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_fwd_kernel<T, 128, 16>
|
||||
<<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_dpch_bwd_launch(
|
||||
T *tokens_grad, T *expert_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int h) {
|
||||
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int s, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_dpch_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_bwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_bwd_kernel<T, 32, 8><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_bwd_kernel<T, 32, 8>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_bwd_kernel<T, 32, 16><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_bwd_kernel<T, 32, 16>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_bwd_kernel<T, 64, 16><<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_bwd_kernel<T, 64, 16>
|
||||
<<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_bwd_kernel<T, 128, 16><<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
moe_dpch_bwd_kernel<T, 128, 16>
|
||||
<<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_cb_fwd_launch(
|
||||
T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2, int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 512)
|
||||
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 1024)
|
||||
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 2048)
|
||||
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else
|
||||
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1,
|
||||
dest2, e, c, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_cb_bwd_launch(
|
||||
T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
||||
T *logits_grad, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int s, const int e, const int c,
|
||||
const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>
|
||||
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
else // if (h < 512)
|
||||
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>
|
||||
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
// else if (h < 1024)
|
||||
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
|
||||
// dest1, dest2, e, c, h);
|
||||
// else
|
||||
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
|
||||
// dest1, dest2, e, c, h);
|
||||
}
|
||||
|
||||
void cumsum_launch(
|
||||
int *inputs, int *outputs,
|
||||
const int s, const int e) {
|
||||
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
||||
|
||||
if (s <= 256)
|
||||
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
|
||||
@ -569,16 +560,13 @@ void cumsum_launch(
|
||||
// API FUNCTIONS --------------------------------
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
@ -587,14 +575,14 @@ void cumsum_launch(
|
||||
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(
|
||||
int s, int ec, int h,
|
||||
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros({ec, h},
|
||||
auto res = torch::zeros(
|
||||
{ec, h},
|
||||
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
@ -603,22 +591,20 @@ torch::Tensor moe_dispatch_cuda_forward(
|
||||
moe_dpch_fwd_launch<scalar_t>(
|
||||
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, h)
|
||||
);
|
||||
dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_backward(
|
||||
int s, int ec, int h,
|
||||
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros({s, h},
|
||||
torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||
auto res = torch::zeros(
|
||||
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
@ -626,65 +612,62 @@ torch::Tensor moe_dispatch_cuda_backward(
|
||||
moe_dpch_bwd_launch<scalar_t>(
|
||||
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, h)
|
||||
);
|
||||
dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_combine_cuda_forward(
|
||||
int s, int e, int c, int h,
|
||||
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto res = torch::zeros({s, h},
|
||||
auto res = torch::zeros(
|
||||
{s, h},
|
||||
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_tokens.scalar_type(), "moe combine forward",
|
||||
moe_cb_fwd_launch<scalar_t>(
|
||||
expert_tokens.data<scalar_t>(), res.data<scalar_t>(), logits.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, e, c, h)
|
||||
);
|
||||
expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
logits.data<scalar_t>(), mask[0].data<int>(),
|
||||
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
|
||||
h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h,
|
||||
torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
std::vector<torch::Tensor>
|
||||
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits,
|
||||
torch::Tensor mask, torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
assert(tokens_grad.dtype() == expert_tokens.dtype());
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto egrad = torch::zeros({e * c, h},
|
||||
auto egrad = torch::zeros(
|
||||
{e * c, h},
|
||||
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
|
||||
wgrad = torch::zeros({s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
||||
wgrad = torch::zeros(
|
||||
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tokens_grad.scalar_type(), "moe combine backward",
|
||||
moe_cb_bwd_launch<scalar_t>(
|
||||
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(), expert_tokens.data<scalar_t>(),
|
||||
logits.data<scalar_t>(), wgrad.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, e, c, h)
|
||||
);
|
||||
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
|
||||
expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
|
||||
wgrad.data<scalar_t>(), mask[0].data<int>(),
|
||||
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
|
||||
h));
|
||||
|
||||
return {egrad, wgrad};
|
||||
}
|
||||
@ -695,7 +678,8 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
|
||||
assert(mask.dtype() == torch::kInt32);
|
||||
|
||||
const int s = mask.size(0), e = mask.size(1);
|
||||
auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
|
||||
auto res =
|
||||
torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
|
||||
cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
|
||||
|
||||
return res;
|
||||
|
Loading…
Reference in New Issue
Block a user