From 219df6e685c7ed1d0b712ce673debc8684bc3779 Mon Sep 17 00:00:00 2001 From: 1SAA Date: Fri, 18 Feb 2022 20:42:31 +0800 Subject: [PATCH] Optimized MoE layer and fixed some bugs; Decreased moe tests; Added FFNExperts and ViTMoE model --- .pre-commit-config.yaml | 2 +- colossalai/global_variables.py | 4 + .../csrc/layer_norm_cuda_kernel.cu | 2 +- .../kernel/cuda_native/csrc/moe_cuda.cpp | 118 +++ .../cuda_native/csrc/moe_cuda_kernel.cu | 702 ++++++++++++++++++ colossalai/nn/layer/moe/__init__.py | 11 +- colossalai/nn/layer/moe/_operation.py | 89 ++- colossalai/nn/layer/moe/experts.py | 96 +++ colossalai/nn/layer/moe/layers.py | 300 ++++---- colossalai/nn/layer/moe/utils.py | 32 + model_zoo/moe/models.py | 104 ++- setup.py | 4 + tests/test_moe/short_test.py | 97 +++ tests/test_moe/test_top1.py | 97 +++ tests/test_moe/test_top2.py | 97 +++ 15 files changed, 1552 insertions(+), 203 deletions(-) create mode 100644 colossalai/kernel/cuda_native/csrc/moe_cuda.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu create mode 100644 colossalai/nn/layer/moe/experts.py create mode 100644 colossalai/nn/layer/moe/utils.py create mode 100644 tests/test_moe/short_test.py create mode 100644 tests/test_moe/test_top1.py create mode 100644 tests/test_moe/test_top2.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df741de47..98ecd0314 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,6 @@ repos: hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v13.0.0 + rev: v13.0.1 hooks: - id: clang-format diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py index 04f6e891e..f6eab02be 100644 --- a/colossalai/global_variables.py +++ b/colossalai/global_variables.py @@ -56,6 +56,7 @@ class MoeEnv: self.data_parallel_size = None self.model_parallel_size = None self.aux_loss = None + self.enable_cuda = True def setup(self, moe_model_size): from .core import global_context as gpc @@ -71,6 +72,9 @@ class MoeEnv: def is_initialized(self): return self.model_parallel_size is not None + def set_cuda_false(self): + self.enable_cuda = False + def reset_loss(self): self.aux_loss = 0 diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu index dc52f8019..ae8fc871a 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -5,7 +5,7 @@ #include "ATen/ATen.h" #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" -#include +#include "ATen/cuda/DeviceUtils.cuh" #include #include diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp new file mode 100644 index 000000000..063fbc664 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -0,0 +1,118 @@ +#include + + +torch::Tensor moe_dispatch_cuda_forward( + int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward( + int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward( + int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + 
torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward( + int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward( + s, ec, h, + batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward( + int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward( + s, ec, h, + expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward( + int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward( + s, e, c, h, + expert_tokens, logits, mask, dest_idx); +} + +std::vector moe_combine_backward( + int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward( + s, e, c, h, + tokens_grad, expert_tokens, logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, + "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu new file mode 100644 index 000000000..ea0b45a45 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -0,0 +1,702 @@ +#include +#include +#include +#include +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage 
ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd( + T *src_row, T *dst_row, + const T weight, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd( + T *src_row, T *dst_row, T *tks_row, T *weight_grad, + const T weight, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) + *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd( + T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, const int cols) { + + assert(cols % pack_size == 0); + const 
int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd( + T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, T *weight_grad2, + const T weight1, const T weight2, const int cols) { + + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], + sgrad1[pack_size], sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + + #pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); + +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector( + T *src_row, T *dst_row1, T *dst_row2, const int cols, + const int indicator1, const int indicator2) { + + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd( + src_row, dst_row1, dst_row2, cols); + else if (indicator1 != 0) + moe_dpch_one_fwd( + src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd( + src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector( + T *src_row, T *dst_row1, T *dst_row2, const int cols, + const int indicator1, const int indicator2) { + + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd( + src_row, dst_row1, dst_row2, cols); + else if (indicator1 != 0) + moe_dpch_one_bwd( + src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd( + src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel( + T *batch_tokens, T *expert_input, + int *mask1, int *mask2, + int *dest1, int *dest2, const int h) { + + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), + expert_input + (dest1[row] * h), expert_input + (dest2[row] * h), + h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel( + T *tokens_grad, T *expert_grad, + int *mask1, int *mask2, + int *dest1, int *dest2, const int h) { + + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 
0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector( + T *src_row1, T *src_row2, T *dst_row, const int cols, + const T weight1, const T weight2, + const int indicator1, const int indicator2) { + + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd( + src_row1, src_row2, dst_row, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd( + src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd( + src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector( + T *src_row1, T *src_row2, T *dst_row, const int cols, + T *tks_row1, T *tks_row2, T *wt_grad1, T *wt_grad2, + const T weight1, const T weight2, + const int indicator1, const int indicator2) { + + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd( + src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, wt_grad2, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd( + src_row1, dst_row, tks_row1, wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd( + src_row2, dst_row, tks_row2, wt_grad2, weight2, cols); + else + return; +} + + +template +__global__ void moe_cb_fwd_kernel( + T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, + int *dest1, int *dest2, + const int e, const int c, const int h) { + + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, + row_log[eid1], row_log[eid2], + mask1[row], indicator2); +} + +template +__global__ void moe_cb_bwd_kernel( + T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, + int *mask1, int *mask2, + int *dest1, int *dest2, + const int e, const int c, const int h) { + + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 
0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, + tks + (dest1[row] * h), tks + (dest2[row] * h), + row_grad + eid1, row_grad + eid2, + row_log[eid1], row_log[eid2], + mask1[row], indicator2); +} + +//CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel( + int *inputs, int *outputs, + const int s, const int e) { + + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, + tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; + #pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } + #pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, + k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) + temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; + #pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +//LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch( + T *batch_tokens, T *expert_input, + int *mask1, int *mask2, + int *dest1, int *dest2, + const int s, const int h) { + + if (h < 256) + moe_dpch_fwd_kernel<<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel<<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel<<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel<<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel<<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch( + T *tokens_grad, T *expert_grad, + int *mask1, int *mask2, + int *dest1, int *dest2, + const int s, const int h) { + + if (h < 256) + moe_dpch_bwd_kernel<<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel<<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel<<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel<<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel<<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch( + T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, + int *dest1, int *dest2, + const int s, const int e, const 
int c, const int h) { + + if (h < 256) + moe_cb_fwd_kernel<<>> + (expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>> + (expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>> + (expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>> + (expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h); + else + moe_cb_fwd_kernel<<>> + (expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h); +} + +template +void moe_cb_bwd_launch( + T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, + int *mask1, int *mask2, + int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + + if (h < 256) + moe_cb_bwd_kernel<<>> + (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>> + (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h); +} + +void cumsum_launch( + int *inputs, int *outputs, + const int s, const int e) { + + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type.");\ + } + +torch::Tensor moe_dispatch_cuda_forward( + int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + + assert(h % 16 == 0); + auto res = torch::zeros({ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), k == 1 ? dest_idx[0].data() : dest_idx[1].data(), + s, h) + ); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward( + int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + + assert(h % 16 == 0); + auto res = torch::zeros({s, h}, + torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), + s, h) + ); + + return res; +} + +torch::Tensor moe_combine_cuda_forward( + int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros({s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), logits.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), k == 1 ? dest_idx[0].data() : dest_idx[1].data(), + s, e, c, h) + ); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros({e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros({s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), expert_tokens.data(), + logits.data(), wgrad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), k == 1 ? dest_idx[0].data() : dest_idx[1].data(), + s, e, c, h) + ); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index e75aff6ed..95a9884a9 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,8 +1,5 @@ -from ._operation import AllToAll -from .layers import Experts, MoeLayer, \ - NormalNoiseGenerator, Top1Router, Top2Router +from .experts import Experts, FFNExperts +from .layers import MoeLayer, Top1Router, Top2Router +from .utils import NormalNoiseGenerator -__all__ = [ - 'AllToAll', 'Experts', 'Top1Router', 'Top2Router', - 'MoeLayer', 'NormalNoiseGenerator' -] \ No newline at end of file +__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator'] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index fd2720fb9..053de0ef6 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -6,16 +6,26 @@ from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from typing import Any, Tuple +U_CUDA_MODE = False +try: + import colossal_moe_cuda + + U_CUDA_MODE = True +except ImportError: + print("If you want to activate cuda mode for MoE, please install with cuda_ext!") + class AllToAll(torch.autograd.Function): """Dispatches input tensor [e, c, h] to all experts by all_to_all_single operation in torch.distributed. 
""" + @staticmethod def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor: - ctx.parallel_mode = parallel_mode + if ctx is not None: + ctx.parallel_mode = parallel_mode if not inputs.is_contiguous(): inputs = inputs.contiguous() @@ -26,4 +36,79 @@ class AllToAll(torch.autograd.Function): @staticmethod def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None + return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None + + +class MoeDispatch(torch.autograd.Function): + + @staticmethod + def forward(ctx, tokens, mask, dest_idx, ec): + s = tokens.size(0) + h = tokens.size(1) + + expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + + ctx.save_for_backward(mask, dest_idx) + ctx.s = s + ctx.h = h + ctx.ec = ec + + return expert_input + + @staticmethod + def backward(ctx, output_grad): + mask, dest_idx = ctx.saved_tensors + d_tokens = colossal_moe_cuda.dispatch_backward( + ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + return d_tokens, None, None, None + + +class MoeCombine(torch.autograd.Function): + + @staticmethod + def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): + assert logits.dtype == torch.float32 + + s = logits.size(0) + e = logits.size(1) + c = ec // e + h = expert_tokens.size(-1) + + fp16_flag = (expert_tokens.dtype == torch.float16) + cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens + ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, + cb_input, logits, + mask, dest_idx) + output = ctokens.to(torch.float16) if fp16_flag else ctokens + + ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) + ctx.s = s + ctx.e = e + ctx.c = c + ctx.h = h + ctx.fp16_flag = fp16_flag + + return output + + @staticmethod + def backward(ctx, tokens_grad): + expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ + else tokens_grad + cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens + d_expert, d_logits = colossal_moe_cuda.combine_backward( + ctx.s, ctx.e, ctx.c, ctx.h, + cb_grad, cb_input, logits, mask, dest_idx) + d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert + + return d_expert, d_logits, None, None, None + + +def moe_cumsum(inputs: Tensor): + dim0 = inputs.size(0) + flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) + if flag and U_CUDA_MODE: + return colossal_moe_cuda.cumsum_sub_one(inputs) + else: + return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py new file mode 100644 index 000000000..4688c9ce7 --- /dev/null +++ b/colossalai/nn/layer/moe/experts.py @@ -0,0 +1,96 @@ +import math + +import torch +import torch.nn as nn +from colossalai.global_variables import moe_env +from colossalai.context import ParallelMode, seed +from colossalai.utils import get_current_device + + +class Experts(nn.Module): + """A wrapper class to create experts. It will create E experts across the + moe model parallel group, where E is the number of experts. Every expert + is a instence of the class, 'expert' in initialization parameters. 
+ + :param expert: The class of all experts + :param num_experts: The number of experts + :param expert_args: Args used to initialize experts + + :type num_experts: int + """ + + def __init__(self, expert, num_experts, **expert_args): + super().__init__() + + assert num_experts % moe_env.model_parallel_size == 0, \ + "The number of experts should be divied by moe model size" + + num_local_experts = num_experts // moe_env.model_parallel_size + with seed(ParallelMode.MOE_MODEL): + self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)]) + self.num_local_experts = num_local_experts + for exp in self.experts: + for param in exp.parameters(): + param.__setattr__('moe_param', True) + + def forward(self, inputs): + expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) + expert_output = [] + + for i in range(self.num_local_experts): + expert_output.append(self.experts[i](expert_input[i])) + + output = torch.cat(expert_output, dim=1).contiguous() + return output + + +class FFNExperts(nn.Module): + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__() + + assert num_experts % moe_env.model_parallel_size == 0, \ + "The number of experts should be divied by moe model size" + + num_local_experts = num_experts // moe_env.model_parallel_size + + self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + for param in self.parameters(): + param.__setattr__('moe_param', True) + + def forward(self, inputs): # x [g, el, c, h] + + el = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(el, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, inter, self.w2) + with seed(ParallelMode.TENSOR): + outputs = self.drop(out_model) # outputs [el, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index ab9c72395..0abe7ac8c 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -3,70 +3,13 @@ import math import torch import torch.nn as nn import torch.nn.functional as F -from torch.cuda.amp import autocast +import torch.distributed as dist +from colossalai.core import global_context as gpc from colossalai.global_variables import moe_env -from colossalai.context import ParallelMode, seed +from colossalai.context import ParallelMode from colossalai.utils import get_current_device -from ._operation import AllToAll - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logtis tensor. 
- - All noise is generated from a normal distribution (0, 1 / E^2), where - E = the number of experts. - - :param num_experts: The number of experts - :type num_experts: int - """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device()) - ).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class Experts(nn.Module): - """A wrapper class to create experts. It will create E experts across the - moe model parallel group, where E is the number of experts. Every expert - is a instence of the class, 'expert' in initialization parameters. - - :param expert: The class of all experts - :param num_experts: The number of experts - :param expert_args: Args used to initialize experts - - :type num_experts: int - """ - - def __init__(self, expert, num_experts, **expert_args): - super().__init__() - - assert num_experts % moe_env.model_parallel_size == 0, \ - "The number of experts should be divied by moe model size" - - num_local_experts = num_experts // moe_env.model_parallel_size - with seed(ParallelMode.MOE_MODEL): - self.experts = nn.ModuleList([ - expert(**expert_args) for _ in range(num_local_experts)]) - self.num_local_experts = num_local_experts - for exp in self.experts: - for param in exp.parameters(): - param.__setattr__('moe_param', 1) - - def forward(self, inputs): - expert_input = torch.chunk(inputs, self.num_local_experts, dim=0) - expert_output = [] - - for i in range(self.num_local_experts): - expert_output.append(self.experts[i](expert_input[i])) - - output = torch.cat(expert_output, dim=0) - return output +from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum +from .utils import autocast_softmax class Top1Router(nn.Module): @@ -83,63 +26,79 @@ class Top1Router(nn.Module): :type noisy_func: Callable, optional """ - def __init__(self, - capacity_factor: float, - min_capacity: int, - noisy_func=None): + def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None): super().__init__() self.capacity_factor = capacity_factor self.min_capacity = min_capacity + self.select_policy = select_policy self.noisy_func = noisy_func - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device())).rsample - def get_capacity(self, logits_shape): - capacity = math.ceil(self.capacity_factor * - logits_shape[0] / logits_shape[1]) - if capacity < self.min_capacity: - capacity = self.min_capacity + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def get_capacity( + self, + logits_shape, + ): + capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 return capacity - def forward(self, inputs): + def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): if self.noisy_func is not None: inputs_noisy = self.noisy_func(inputs) else: inputs_noisy = inputs - logits = F.softmax(inputs, dim=1) - - num_experts = logits.shape[1] + logits = autocast_softmax(inputs, dim=-1) + num_experts = 
logits.size(-1) capacity = self.get_capacity(logits.shape) - expert_idx = torch.argmax(inputs_noisy, dim=1) - expert_mask = F.one_hot(expert_idx, num_classes=num_experts) - expert_mask_f = expert_mask.float() + top1_idx = torch.argmax(inputs_noisy, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu') + if self.training: + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + moe_env.add_loss(l_aux) + else: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL)) + capacity = max_num.item() - me = torch.mean(logits, dim=0) - ce = torch.mean(expert_mask_f, dim=0) - l_aux = torch.sum(me * ce) * num_experts - moe_env.add_loss(l_aux) + if not self.training: + ranks = moe_cumsum(mask) + elif self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") - rand_mask = expert_mask * self.uniform(logits.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + ranks = torch.sum(mask * ranks, dim=-1) - dispatch_mask = \ - expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1) - - locations = torch.cumsum(dispatch_mask, dim=0) - 1 - locations = torch.sum(dispatch_mask * locations, dim=1) - locations = F.one_hot(locations, num_classes=capacity) - - logits = logits * dispatch_mask - combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1) - - sec_mask = combine_weights.bool() - return combine_weights, sec_mask, exp_counts + if cuda_mode: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask class Top2Router(nn.Module): @@ -159,53 +118,67 @@ class Top2Router(nn.Module): self.noisy_func = noisy_func def get_capacity(self, logits_shape): - capacity = math.ceil(2 * self.capacity_factor * - logits_shape[0] / logits_shape[1]) + capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + assert capacity > 0 return capacity - def forward(self, inputs): + def forward(self, inputs: torch.Tensor, cuda_mode: bool = False): + # inputs: [s, h] if self.noisy_func is not None: inputs = self.noisy_func(inputs) - logits = F.softmax(inputs, dim=-1) + logits = autocast_softmax(inputs, dim=-1) # logits: [s, e] num_experts = logits.size(-1) capacity = self.get_capacity(logits.shape) - _, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True) - top1_idx = expert_idx[:, 0] - top2_idx = expert_idx[:, 1] + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, 
num_classes=num_experts).to(torch.int32) - mask1 = F.one_hot(top1_idx, num_classes=num_experts) - mask2 = F.one_hot(top2_idx, num_classes=num_experts) + cmask = (mask1 + mask2) # loss: [s, e] + if self.training: + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 + moe_env.add_loss(l_aux) + else: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL)) + capacity = max_num.item() - loss_mask = (mask1 + mask2) - exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu') - me = torch.mean(logits, dim=0) - ce = torch.mean(loss_mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 - moe_env.add_loss(l_aux) + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) - locations1 = torch.cumsum(mask1, dim=0) - 1 - locations2 = torch.cumsum(mask2, dim=0) - 1 - locations2 += torch.sum(mask1, dim=0, keepdim=True) + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) - mask1 *= torch.lt(locations1, capacity) - mask2 *= torch.lt(locations2, capacity) + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) - weight1 = mask1 * logits - weight2 = mask2 * logits + if cuda_mode: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) - locations1 = torch.sum(mask1 * locations1, dim=1) - locations2 = torch.sum(mask2 * locations2, dim=1) - locations1_sc = F.one_hot(locations1, num_classes=capacity) - locations2_sc = F.one_hot(locations2, num_classes=capacity) + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1) - combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1) - combine_weights = combine_weights1 + combine_weights2 - sec_mask = combine_weights.bool() + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) - return combine_weights, sec_mask, exp_counts + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask class MoeLayer(nn.Module): @@ -225,52 +198,47 @@ class MoeLayer(nn.Module): :type experts: nn.Module """ - def __init__(self, - dim_model: int, - num_experts: int, - router: nn.Module, - experts: nn.Module): + def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module): super().__init__() self.d_model = dim_model self.num_experts = num_experts - self.gate = nn.Linear(dim_model, num_experts, device=get_current_device()) + self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device()) self.router = router self.experts = experts + self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False - def _router_part(self, tokens: torch.Tensor): - gate_output = self.gate(tokens) - return self.router(gate_output) + def expert_part(self, expert_input: torch.Tensor): + expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL) - def router_part(self, tokens: torch.Tensor): - autocast_context = 
torch.is_autocast_enabled() - if not autocast_context: - return self._router_part(tokens) - else: - with autocast(enabled=False): - if tokens.dtype == torch.float16: - input_tokens = tokens.float() - else: - input_tokens = tokens - return self._router_part(input_tokens) + input_shape = expert_input.shape + + expert_input = expert_input.reshape(moe_env.model_parallel_size, + self.num_experts // moe_env.model_parallel_size, -1, self.d_model) + + expert_output = self.experts(expert_input) + expert_output = expert_output.reshape(input_shape) + + expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL) + return expert_output def forward(self, inputs: torch.Tensor) -> torch.Tensor: tokens = inputs.reshape(-1, self.d_model) + gate_output = self.gate(tokens) + router_res = self.router(gate_output, self.cuda_mode) - combine_weights, sec_mask, exp_counts = self.router_part(tokens) + if self.cuda_mode: + logits, mask, dest_idx, ec = router_res + expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec) + expert_output = self.expert_part(expert_input) + ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec) + else: + combine_weights, sec_mask = router_res + sec_mask_f = sec_mask.type_as(inputs) + expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + expert_output = self.expert_part(expert_input) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ret = torch.matmul(combine_weights, expert_output) - sec_mask_f = sec_mask.type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL) - - expert_output = self.experts(dispatch_data) - - expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL) - - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - - ret = torch.matmul(combine_weights, expert_output) ret = ret.reshape(inputs.shape) - return ret diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py new file mode 100644 index 000000000..060741c4f --- /dev/null +++ b/colossalai/nn/layer/moe/utils.py @@ -0,0 +1,32 @@ +import torch +import torch.nn.functional as F +from colossalai.utils import get_current_device + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logtis tensor. + + All noise is generated from a normal distribution (0, 1 / E^2), where + E = the number of experts. 
+ + :param num_experts: The number of experts + :type num_experts: int + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal( + loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device()) + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +def autocast_softmax(inputs: torch.Tensor, dim: int): + assert inputs.dtype in {torch.float16, torch.float32} + fp16_flag = (inputs.dtype == torch.float16) + sm_input = inputs.to(torch.float32) if fp16_flag else inputs + sm_output = F.softmax(sm_input, dim) + return sm_output diff --git a/model_zoo/moe/models.py b/model_zoo/moe/models.py index 2be7b21ae..cffd837a4 100644 --- a/model_zoo/moe/models.py +++ b/model_zoo/moe/models.py @@ -4,7 +4,7 @@ import torch.nn as nn from colossalai.context import ParallelMode from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \ WrappedDropout as Dropout, WrappedDropPath as DropPath -from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator +from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator from .util import moe_sa_args, moe_mlp_args from ..helper import TransformerLayer from colossalai.global_variables import moe_env @@ -81,6 +81,7 @@ class VanillaFFN(nn.Module): class Widenet(nn.Module): + def __init__(self, num_experts: int, capacity_factor: float, @@ -98,43 +99,33 @@ class Widenet(nn.Module): drop_path: float = 0.): super().__init__() - embedding = VanillaPatchEmbedding( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_size=d_model) + embedding = VanillaPatchEmbedding(img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_size=d_model) embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR) shared_sa = VanillaSelfAttention(**moe_sa_args( - d_model=d_model, n_heads=num_heads, d_kv=d_kv, - attention_drop=attention_drop, drop_rate=drop_rate)) + d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate)) noisy_func = NormalNoiseGenerator(num_experts) shared_router = Top2Router(capacity_factor, noisy_func=noisy_func) - shared_experts = Experts(expert=VanillaFFN, - num_experts=num_experts, - **moe_mlp_args( - d_model=d_model, - d_ff=d_ff, - drop_rate=drop_rate - )) + shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate) # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] blocks = [ - TransformerLayer( - att=shared_sa, - ffn=MoeLayer(dim_model=d_model, num_experts=num_experts, - router=shared_router, experts=shared_experts), - norm1=nn.LayerNorm(d_model, eps=1e-6), - norm2=nn.LayerNorm(d_model, eps=1e-6), - droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR) - ) - for i in range(depth) + TransformerLayer(att=shared_sa, + ffn=MoeLayer(dim_model=d_model, + num_experts=num_experts, + router=shared_router, + experts=shared_experts), + norm1=nn.LayerNorm(d_model, eps=1e-6), + norm2=nn.LayerNorm(d_model, eps=1e-6), + droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) for i in range(depth) ] norm = nn.LayerNorm(d_model, eps=1e-6) - self.linear = VanillaClassifier(in_features=d_model, - num_classes=num_classes) + self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes) nn.init.zeros_(self.linear.weight) nn.init.zeros_(self.linear.bias) self.widenet = nn.Sequential(embedding, 
embed_dropout, *blocks, norm) @@ -145,3 +136,64 @@ class Widenet(nn.Module): x = torch.mean(x, dim=1) x = self.linear(x) return x + + +class ViTMoE(nn.Module): + + def __init__(self, + num_experts: int, + capacity_factor: float, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + num_classes: int = 1000, + depth: int = 12, + d_model: int = 768, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 3072, + attention_drop: float = 0., + drop_rate: float = 0.1, + drop_path: float = 0.): + super().__init__() + + embedding = VanillaPatchEmbedding(img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_size=d_model) + embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR) + + noisy_func = NormalNoiseGenerator(num_experts) + router = Top2Router(capacity_factor, noisy_func=noisy_func) + + assert depth % 2 == 0 + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] + blocks = [] + for i in range(depth): + sa = VanillaSelfAttention(**moe_sa_args( + d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate)) + ffn = VanillaFFN(**moe_mlp_args( + d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \ + MoeLayer(dim_model=d_model, num_experts=num_experts, router=router, + experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)) + layer = TransformerLayer(att=sa, + ffn=ffn, + norm1=nn.LayerNorm(d_model, eps=1e-6), + norm2=nn.LayerNorm(d_model, eps=1e-6), + droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) + blocks.append(layer) + + norm = nn.LayerNorm(d_model, eps=1e-6) + self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm) + + def forward(self, x): + moe_env.reset_loss() + x = self.vitmoe(x) + x = torch.mean(x, dim=1) + x = self.linear(x) + return x diff --git a/setup.py b/setup.py index 57328d28a..8b25229f7 100644 --- a/setup.py +++ b/setup.py @@ -162,6 +162,10 @@ if build_cuda_ext: ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], extra_cuda_flags + cc_flag)) + ext_modules.append(cuda_ext_helper('colossal_moe_cuda', + ['moe_cuda.cpp', 'moe_cuda_kernel.cu'], + extra_cuda_flags + cc_flag)) + extra_cuda_flags = ['-maxrregcount=50'] ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda', diff --git a/tests/test_moe/short_test.py b/tests/test_moe/short_test.py new file mode 100644 index 000000000..3b919345d --- /dev/null +++ b/tests/test_moe/short_test.py @@ -0,0 +1,97 @@ +import os +from functools import partial +from pathlib import Path +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import free_port, get_current_device +from colossalai.nn.layer.moe import Top2Router, MoeLayer +from colossalai.global_variables import moe_env + + +BATCH_SIZE = 32 +NUM_EXPERTS = 4 +CONFIG = dict(parallel=dict(moe=dict(size=4))) + + +def check_equal(A, B, atol=1e-06): + assert torch.allclose(A, B, rtol=0, atol=atol) is True + + +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # torch.set_printoptions(precision=30) + 
torch.backends.cuda.matmul.allow_tf32 = False + local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) + torch.manual_seed(rs + local_rank) + moe_env.reset_loss() + tokens = torch.randn(BATCH_SIZE, hidden_size, + dtype=data_type, device=get_current_device(), requires_grad=True) + # print(f"tokens:\n{tokens}") + router = Top2Router(1) + layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity()) + if data_type == torch.float16: + layer = layer.half() + layer.cuda_mode = False + + old_out = layer(tokens) + # print(f"old output:\n{old_out}") + + ech = old_out.shape + grad = torch.randn(ech, device=get_current_device()) + old_out.backward(grad) + + o_tk_grad = tokens.grad.data.clone() + o_gt_grad = layer.gate.weight.grad.data.clone() + + tokens.grad.zero_() + layer.gate.weight.grad.zero_() + + layer.cuda_mode = True + new_out = layer(tokens) + + # print(torch.max(torch.abs(old_out - new_out))) + if data_type == torch.float32: + check_equal(old_out, new_out) + else: + check_equal(old_out, new_out, 1e-2) + # print(f"forward functions passed") + + # print(f"new output:\n{new_out}") + new_out.backward(grad) + n_tk_grad = tokens.grad.data.clone() + n_gt_grad = layer.gate.weight.grad.data.clone() + + # print(torch.max(torch.abs(o_tk_grad - n_tk_grad))) + if data_type == torch.float32: + check_equal(o_tk_grad, n_tk_grad) + else: + check_equal(o_tk_grad, o_tk_grad, 1e-2) + # print(f"tokens gradient passed") + + # print(torch.max(torch.abs(o_gt_grad - n_gt_grad))) + if data_type == torch.float32: + check_equal(o_gt_grad, n_gt_grad, 5e-05) + else: + check_equal(o_gt_grad, n_gt_grad, 2e-01) + # print(f"linear weight gradient passed") + + +@pytest.mark.dist +@pytest.mark.parametrize("rs", [131]) +@pytest.mark.parametrize("hidden_size", [32, 144]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) +def test_moe_top2(rs, hidden_size, data_type): + world_size = 4 + run_func = partial(run_routing, world_size=world_size, port=free_port(), + rs=rs, hidden_size=hidden_size, data_type=data_type) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_top2(2, 256, torch.float16) diff --git a/tests/test_moe/test_top1.py b/tests/test_moe/test_top1.py new file mode 100644 index 000000000..11986af44 --- /dev/null +++ b/tests/test_moe/test_top1.py @@ -0,0 +1,97 @@ +from functools import partial +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import free_port, get_current_device +from colossalai.nn.layer.moe import Top1Router, MoeLayer +from colossalai.global_variables import moe_env + +BATCH_SIZE = 32 +NUM_EXPERTS = 4 +CONFIG = dict(parallel=dict(moe=dict(size=4))) + + +def check_equal(A, B, atol=1e-06): + assert torch.allclose(A, B, rtol=0, atol=atol) is True + + +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # torch.set_printoptions(precision=30) + torch.backends.cuda.matmul.allow_tf32 = False + local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) + torch.manual_seed(rs + local_rank) + moe_env.reset_loss() + tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + # print(f"tokens:\n{tokens}") + router = Top1Router(1) + layer = MoeLayer(hidden_size, NUM_EXPERTS, router, 
nn.Identity()) + if data_type == torch.float16: + layer = layer.half() + layer.cuda_mode = False + + old_out = layer(tokens) + # print(f"old output:\n{old_out}") + + ech = old_out.shape + grad = torch.randn(ech, device=get_current_device()) + old_out.backward(grad) + + o_tk_grad = tokens.grad.data.clone() + o_gt_grad = layer.gate.weight.grad.data.clone() + + tokens.grad.zero_() + layer.gate.weight.grad.zero_() + + layer.cuda_mode = True + new_out = layer(tokens) + + # print(torch.max(torch.abs(old_out - new_out))) + if data_type == torch.float32: + check_equal(old_out, new_out) + else: + check_equal(old_out, new_out, 1e-2) + # print(f"forward functions passed") + + # print(f"new output:\n{new_out}") + new_out.backward(grad) + n_tk_grad = tokens.grad.data.clone() + n_gt_grad = layer.gate.weight.grad.data.clone() + + # print(torch.max(torch.abs(o_tk_grad - n_tk_grad))) + if data_type == torch.float32: + check_equal(o_tk_grad, n_tk_grad) + else: + check_equal(o_tk_grad, o_tk_grad, 1e-2) + # print(f"tokens gradient passed") + + # print(torch.max(torch.abs(o_gt_grad - n_gt_grad))) + if data_type == torch.float32: + check_equal(o_gt_grad, n_gt_grad, 5e-05) + else: + check_equal(o_gt_grad, n_gt_grad, 2e-01) + # print(f"linear weight gradient passed") + + +@pytest.mark.skip(reason="Should be activated for detailed tests") +@pytest.mark.parametrize("rs", [2, 42, 60]) +@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) +def test_moe_top2(rs, hidden_size, data_type): + world_size = 4 + run_func = partial(run_routing, + world_size=world_size, + port=free_port(), + rs=rs, + hidden_size=hidden_size, + data_type=data_type) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_top2(60, 512, torch.float16) diff --git a/tests/test_moe/test_top2.py b/tests/test_moe/test_top2.py new file mode 100644 index 000000000..41500530f --- /dev/null +++ b/tests/test_moe/test_top2.py @@ -0,0 +1,97 @@ +from functools import partial +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.utils import free_port, get_current_device +from colossalai.nn.layer.moe import Top2Router, MoeLayer +from colossalai.global_variables import moe_env + +BATCH_SIZE = 32 +NUM_EXPERTS = 4 +CONFIG = dict(parallel=dict(moe=dict(size=4))) + + +def check_equal(A, B, atol=1e-06): + assert torch.allclose(A, B, rtol=0, atol=atol) is True + + +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # torch.set_printoptions(precision=30) + torch.backends.cuda.matmul.allow_tf32 = False + local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) + torch.manual_seed(rs + local_rank) + moe_env.reset_loss() + tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + # print(f"tokens:\n{tokens}") + router = Top2Router(1) + layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity()) + if data_type == torch.float16: + layer = layer.half() + layer.cuda_mode = False + + old_out = layer(tokens) + # print(f"old output:\n{old_out}") + + ech = old_out.shape + grad = torch.randn(ech, device=get_current_device()) + old_out.backward(grad) + + o_tk_grad = tokens.grad.data.clone() + 
o_gt_grad = layer.gate.weight.grad.data.clone() + + tokens.grad.zero_() + layer.gate.weight.grad.zero_() + + layer.cuda_mode = True + new_out = layer(tokens) + + # print(torch.max(torch.abs(old_out - new_out))) + if data_type == torch.float32: + check_equal(old_out, new_out) + else: + check_equal(old_out, new_out, 1e-2) + # print(f"forward functions passed") + + # print(f"new output:\n{new_out}") + new_out.backward(grad) + n_tk_grad = tokens.grad.data.clone() + n_gt_grad = layer.gate.weight.grad.data.clone() + + # print(torch.max(torch.abs(o_tk_grad - n_tk_grad))) + if data_type == torch.float32: + check_equal(o_tk_grad, n_tk_grad) + else: + check_equal(o_tk_grad, o_tk_grad, 1e-2) + # print(f"tokens gradient passed") + + # print(torch.max(torch.abs(o_gt_grad - n_gt_grad))) + if data_type == torch.float32: + check_equal(o_gt_grad, n_gt_grad, 5e-05) + else: + check_equal(o_gt_grad, n_gt_grad, 2e-01) + # print(f"linear weight gradient passed") + + +@pytest.mark.skip(reason="Should be activated for detailed tests") +@pytest.mark.parametrize("rs", [2, 42, 60]) +@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048]) +@pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) +def test_moe_top2(rs, hidden_size, data_type): + world_size = 4 + run_func = partial(run_routing, + world_size=world_size, + port=free_port(), + rs=rs, + hidden_size=hidden_size, + data_type=data_type) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_top2(2, 256, torch.float16)
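
Usage note (illustrative, not part of the patch): below is a minimal sketch of how the components introduced here (NormalNoiseGenerator, Top2Router, FFNExperts, MoeLayer) fit together, modeled on tests/test_moe/test_top2.py and model_zoo/moe/models.py. It assumes a single 4-GPU node, the parallel=dict(moe=dict(size=4)) config used by the tests, and that colossalai was installed with the CUDA extension so colossal_moe_cuda is importable; otherwise MoeLayer falls back to the pure-PyTorch dispatch/combine path. The hyperparameter values are placeholders.

    from functools import partial

    import torch
    import torch.multiprocessing as mp

    import colossalai
    from colossalai.global_variables import moe_env
    from colossalai.nn.layer.moe import FFNExperts, MoeLayer, NormalNoiseGenerator, Top2Router
    from colossalai.utils import free_port, get_current_device

    CONFIG = dict(parallel=dict(moe=dict(size=4)))    # MoE model-parallel group of 4 ranks
    D_MODEL, D_FF, NUM_EXPERTS, BATCH = 768, 3072, 4, 32


    def run(rank, world_size, port):
        colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                          host='localhost', port=port, backend='nccl')

        # Top-2 routing with normal noise on the gate logits; experts are the fused FFNExperts.
        router = Top2Router(capacity_factor=1.0, noisy_func=NormalNoiseGenerator(NUM_EXPERTS))
        experts = FFNExperts(NUM_EXPERTS, D_MODEL, D_FF, drop_rate=0.1)
        layer = MoeLayer(dim_model=D_MODEL, num_experts=NUM_EXPERTS, router=router, experts=experts)
        # layer.cuda_mode = False  # force the pure-PyTorch path, as the tests do for reference outputs

        moe_env.reset_loss()      # clear the auxiliary balancing loss before the forward pass
        tokens = torch.randn(BATCH, D_MODEL, device=get_current_device(), requires_grad=True)
        out = layer(tokens)       # uses the fused CUDA dispatch/combine kernels when cuda_mode is True
        out.mean().backward()     # the balancing loss accumulated on moe_env.aux_loss can be added to the objective here


    if __name__ == '__main__':
        world_size = 4
        mp.spawn(partial(run, world_size=world_size, port=free_port()), nprocs=world_size)

As in the new model_zoo ViTMoE model, the same MoeLayer can then be dropped into a TransformerLayer in place of a dense FFN, with the router and experts either shared across blocks (Widenet) or created per block (ViTMoE).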