diff --git a/colossalai/__init__.py b/colossalai/__init__.py index f3882e962..e7ea7d65a 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch, get_default_parser) __version__ = '0.0.1' - diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py index 6027d34e6..3d14ce23e 100644 --- a/colossalai/builder/pipeline.py +++ b/colossalai/builder/pipeline.py @@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks) module_list = [] for start, end in partitions[pipeline_rank]: - module_list.append( - nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end], - *[nn.Identity() for _ in range(len(layers) - end)])) + module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)], + *layers[start:end], + *[nn.Identity() for _ in range(len(layers) - end)])) if verbose: logger = get_dist_logger() logger.info(f'Total {len(layers)} layers', ranks=[0]) @@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n' logger.info(log_str, ranks=[0]) return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0] + \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp index 22bec7e27..feb612f9f 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE */ #include "cpu_adam.h" - +#include #include +#include #include #include #include - -#include -#include #include #include @@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + if ((t + TILE) > rounded_size) + copy_size = rounded_size - t; size_t offset = copy_size + t; #pragma omp parallel for @@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, if (_param_size > rounded_size) { for (size_t t = rounded_size; t < _param_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; + if ((t + TILE) > _param_size) + copy_size = _param_size - t; size_t offset = copy_size + t; #pragma omp parallel for @@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + if ((t + TILE) > rounded_size) + copy_size = rounded_size - t; size_t offset = copy_size + t; #pragma omp parallel for @@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3, s_optimizers[optimizer_id] = opt; if (should_log) { + std::string avx_type = ""; #if defined(__AVX512__) avx_type = "AVX512"; @@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + if ((t + TILE) > rounded_size) + copy_size = rounded_size - t; size_t 
offset = copy_size + t; #pragma omp parallel for @@ -460,29 +463,43 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, grad_half_precision, loss_scale); } -int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2, - float epsilon, float weight_decay, bool bias_correction, - torch::Tensor ¶ms, torch::Tensor &grads, - torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq, - float loss_scale) { - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); +int adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + float loss_scale) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); - float *params_ptr = (float *)params_c.data_ptr(); - float *grads_ptr = (float *)grads_c.data_ptr(); - float *exp_avg_ptr = (float *)exp_avg_c.data_ptr(); - float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr(); - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, - params_c.numel(), (params.options().dtype() == at::kHalf), - (grads.options().dtype() == at::kHalf), loss_scale); + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.numel(), + (params.options().dtype() == at::kHalf), + (grads.options().dtype() == at::kHalf), + loss_scale); - return 0; + return 0; } int destroy_adam_optimizer(int optimizer_id) { diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h index 74d2f1b17..023e653d3 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -48,10 +48,10 @@ SOFTWARE #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) #define SIMD_SQRT(x) _mm512_sqrt_ps(x) #define SIMD_DIV(x, y) _mm512_div_ps(x, y) -#define SIMD_LOAD_HALF(x) \ +#define SIMD_LOAD_HALF(x) \ _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm256_store_ps( \ +#define SIMD_STORE_HALF(x, d) \ + _mm256_store_ps( \ x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) #elif defined(__AVX256__) or defined(__AVX2__) @@ -66,8 +66,8 @@ SOFTWARE #define SIMD_SQRT(x) _mm256_sqrt_ps(x) #define SIMD_DIV(x, y) _mm256_div_ps(x, y) #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm_store_ps( \ +#define SIMD_STORE_HALF(x, d) \ + _mm_store_ps( \ x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) #endif @@ -83,25 +83,19 @@ union AVX_Data { #endif -#define STEP(SPAN) \ - void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ - 
float *_exp_avg_sq, size_t _param_size, \ - bool param_half_precision = false, \ +#define STEP(SPAN) \ + void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ + float *_exp_avg_sq, size_t _param_size, \ + bool param_half_precision = false, \ bool grad_half_precision = false, float loss_scale = -1); class Adam_Optimizer { - public: +public: Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, float eps = 1e-8, float weight_decay = 0, bool adamw_mode = true) - : _alpha(alpha), - _betta1(betta1), - _betta2(betta2), - _eps(eps), - _weight_decay(weight_decay), - _betta1_t(1.0), - _betta2_t(1.0), - _step(0), + : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps), + _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0), _adamw_mode(adamw_mode) {} ~Adam_Optimizer() {} @@ -141,7 +135,7 @@ class Adam_Optimizer { } } - private: +private: float _alpha; float _betta1; float _betta2; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu index 58d26235a..b12437870 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu @@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel( const int left_idx = block_start + threadIdx.x; const int right_idx = (blockIdx.x + 1) * vocab_size; float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[2] = {0.f, 0.f}; // logit and logit exp + float sum_logits[2] = {0.f, 0.f}; // logit and logit exp int target_tid = targets[blockIdx.x]; if (target_tid == padding_idx) { diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu index 184106bd2..9ccf09d76 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -1,10 +1,10 @@ -#include - #include #include #include "kernels.h" +#include + namespace cg = cooperative_groups; curandStatePhilox4_32_10_t *curandstate; @@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio, const float scale = 1.f / (1.f - ratio); int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 4 >= total_count) return; + if (i * 4 >= total_count) + return; curandStatePhilox4_32_10_t state; curand_init(seed, i, 0, &state); @@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio, int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 8 >= total_count) return; + if (i * 8 >= total_count) + return; curandStatePhilox4_32_10_t state; curand_init(seed, i, 0, &state); @@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, const float scale = 1.f / (1.f - ratio); int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 4 >= total_count) return; + if (i * 4 >= total_count) + return; uint8_t m[4]; @@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 8 >= total_count) return; + if (i * 8 >= total_count) + return; float4 *out4 = reinterpret_cast(out); const float4 *vals_float4 = reinterpret_cast(in); @@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel( const float scale = 1.f / (1.f - ratio); int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 4 >= total_count) return; + if (i * 4 >= total_count) + return; curandStatePhilox4_32_10_t 
state; curand_init(seed, i, 0, &state); @@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel( int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 8 >= total_count) return; + if (i * 8 >= total_count) + return; curandStatePhilox4_32_10_t state; curand_init(seed, i, 0, &state); @@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel( } __syncthreads(); - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + for (int i = 1; i < 32; i <<= 1) + sum += g.shfl_down(sum, i); - if (y == 0) tile[0][x] = sum; + if (y == 0) + tile[0][x] = sum; __syncthreads(); if (threadIdx.x < 8) { @@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel( } __syncthreads(); - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + for (int i = 1; i < WARP_SIZE; i <<= 1) + sum += g.shfl_down(sum, i); - if (y == 0) tile[0][x] = sum; + if (y == 0) + tile[0][x] = sum; __syncthreads(); if (threadIdx.x < 8) { @@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel( const float scale = 1.f / (1.f - ratio); int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 4 >= total_count) return; + if (i * 4 >= total_count) + return; curandStatePhilox4_32_10_t state; curand_init(seed, i, 0, &state); @@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel( int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i * 8 >= total_count) return; + if (i * 8 >= total_count) + return; curandStatePhilox4_32_10_t state; curand_init(seed, i, 0, &state); @@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel( float sum = tile[threadIdx.y][threadIdx.x]; __syncthreads(); - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + for (int i = 1; i < WARP_SIZE; i <<= 1) + sum += g.shfl_down(sum, i); - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + if (threadIdx.x == 0) + tile[0][threadIdx.y] = sum; __syncthreads(); if (threadIdx.y == 0) { diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu index bc90c54c0..e37bc3d04 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -1,7 +1,7 @@ -#include - #include "kernels.h" +#include + namespace cg = cooperative_groups; /** diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h index 38103c173..c58ed44ba 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h @@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f; const float REDUCE_FLOAT_INF_POS = 100000000.f; const unsigned int WARP_REDUCE_SIZE = 32; -template -__forceinline__ __device__ T warpReduceSum(T val) { +template __forceinline__ __device__ T warpReduceSum(T val) { for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); return val; } /* Calculate the sum of all elements in a block */ -template -__forceinline__ __device__ T blockReduceSum(T val) { +template __forceinline__ __device__ T blockReduceSum(T val) { static __shared__ T shared[32]; int lane = threadIdx.x & 0x1f; int wid = threadIdx.x >> 5; val = warpReduceSum(val); - if (lane == 0) shared[wid] = val; + if (lane == 0) + shared[wid] = val; __syncthreads(); val = (threadIdx.x < (blockDim.x >> 5)) ? 
shared[lane] : (T)0.0f; @@ -57,10 +56,10 @@ __inline__ __device__ void warpReduce(float *pval) { template <> __inline__ __device__ void warpReduce(float *pval) { float val0_tmp, val1_tmp; -#define WarpReduceMaxOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - *(pval) = max(val0_tmp, *(pval)); \ +#define WarpReduceMaxOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval) = max(val0_tmp, *(pval)); \ *(pval + 1) = max(val1_tmp, *(pval + 1)); WarpReduceMaxOneStep(16, 32); @@ -89,10 +88,10 @@ __inline__ __device__ void warpReduce(float *pval) { template <> __inline__ __device__ void warpReduce(float *pval) { float val0_tmp, val1_tmp; -#define WarpReduceSumOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - *(pval + 0) += val0_tmp; \ +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval + 0) += val0_tmp; \ *(pval + 1) += val1_tmp WarpReduceSumOneStep(16, 32); @@ -107,14 +106,14 @@ __inline__ __device__ void warpReduce(float *pval) { template <> __inline__ __device__ void warpReduce(float *pval) { float val0_tmp, val1_tmp, val2_tmp, val3_tmp; -#define WarpReduceSumOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ - val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ - *(pval + 0) += val0_tmp; \ - *(pval + 1) += val1_tmp; \ - *(pval + 2) += val2_tmp; \ +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ + val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp; \ + *(pval + 2) += val2_tmp; \ *(pval + 3) += val3_tmp WarpReduceSumOneStep(16, 32); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h index f7d75f38c..dc80881f9 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h @@ -9,7 +9,7 @@ #include "cuda_util.h" class Context { - public: +public: Context() : _stream(nullptr) { CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); } @@ -30,7 +30,7 @@ class Context { cublasHandle_t get_cublashandle() { return _cublasHandle; } - private: +private: cudaStream_t _stream; cublasHandle_t _cublasHandle; }; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h index f4e9befc6..af7c9c04d 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h @@ -8,9 +8,8 @@ #include "cuda_util.h" -template -class CrossEntropyLayer { - public: +template class CrossEntropyLayer { +public: CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); virtual 
~CrossEntropyLayer(); @@ -23,7 +22,7 @@ class CrossEntropyLayer { void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); - private: +private: void allocate_mem_buffer() { // allocate local gpu memory _loss_buffer = cuda_malloc(_max_batch_tokens * 2); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h index 1595257be..bc2258762 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h @@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file, template void print_vec(const T *outv, std::string outn, int num_output_ele); -template -T *cuda_malloc(size_t ele_num); +template T *cuda_malloc(size_t ele_num); void cuda_free(void *pdata); @@ -29,6 +28,6 @@ template void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, std::string file, int line, cudaStream_t stream); -#define CHECK_NAN_INF(ptr, size, stream) \ - check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ +#define CHECK_NAN_INF(ptr, size, stream) \ + check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h index 563a7fe28..c2a4f7c20 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -3,14 +3,12 @@ #include #include #include - #include #include "kernels.h" -template -class Dropout { - public: +template class Dropout { +public: struct Config { float ratio; bool training; @@ -90,7 +88,7 @@ class Dropout { void SetTrainingMode(bool training) { _config.training = training; } - private: +private: uint8_t *_mask; Config _config; }; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h index ec963259f..9a43aeec3 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h @@ -13,16 +13,14 @@ #include "cublas_wrappers.h" #include "kernels.h" -template -class FeedForward { - public: +template class FeedForward { +public: struct Config { int outputSize; int inputSize; std::array gemm_algos; Config(int outputs, int inputs) - : outputSize(outputs), - inputSize(inputs), + : outputSize(outputs), inputSize(inputs), gemm_algos(std::array({99, 99, 99})) {} }; @@ -63,6 +61,6 @@ class FeedForward { config_.inputSize = inputSize; } - private: +private: Config config_; }; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index ec447ad84..005a36ba1 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -10,9 +10,8 @@ using namespace std; -template -class Softmax { - public: +template class Softmax { +public: struct Config { size_t nhead; Config(size_t nhead) : nhead(nhead) {} @@ -37,6 +36,6 @@ class Softmax { void reset_size(size_t nhead) { config_.nhead = nhead; } - private: +private: Config config_; }; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu index 
3e61d4e35..5eec0d662 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -1,5 +1,6 @@ #include "block_reduce.h" #include "kernels.h" + #include namespace cg = cooperative_groups; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu index 98af433fe..64f0fc2c2 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -1,4 +1,3 @@ -#include #include #include @@ -7,6 +6,8 @@ #include "block_reduce.h" #include "kernels.h" +#include + namespace cg = cooperative_groups; const float EPSILON = 1e-8f; @@ -119,7 +120,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], to_len); } - } // blockIdx.x + } // blockIdx.x } template @@ -197,7 +198,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], to_len); } - } // blockIdx.x + } // blockIdx.x } /* @@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + for (int i = 1; i < WARP_SIZE; i <<= 1) + sum += g.shfl_xor(sum, i); #pragma unroll for (int i = 0; i < ITERATIONS; ++i) { diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp index 4690277e6..8571f5f71 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -2,12 +2,10 @@ * https://github.com/NVIDIA/apex * with minor changes. 
*/ -#include - -#include -#include - #include "compat.h" +#include +#include +#include namespace { @@ -67,7 +65,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape, check_args(input, normalized_shape, n1, n2); check_args(normalized_shape, gamma, beta); } -} // namespace +} // namespace void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, at::Tensor *input, int n1, int n2, @@ -75,16 +73,17 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, at::Tensor *beta, double epsilon); #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ +#define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) std::vector layer_norm_affine(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, double epsilon) { + CHECK_INPUT(input); CHECK_INPUT(gamma); CHECK_INPUT(beta); @@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, double epsilon, at::Tensor *grad_input, at::Tensor *grad_gamma, at::Tensor *grad_beta); -std::vector layer_norm_gradient_affine( - at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, - at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, - double epsilon) { +std::vector +layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar, + at::Tensor input, at::IntArrayRef normalized_shape, + at::Tensor gamma, at::Tensor beta, double epsilon) { + CHECK_INPUT(dout); CHECK_INPUT(mean); CHECK_INPUT(invvar); diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp index 61c8a7250..6aaa15b4e 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -15,24 +15,25 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, torch::Tensor logits, torch::Tensor mask, torch::Tensor dest_idx); -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); +std::vector +moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, + torch::Tensor mask, torch::Tensor dest_idx); torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); -#define CHECK_CUDA(x) \ +#define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ +#define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) torch::Tensor moe_dispatch_forward(int s, int ec, int h, torch::Tensor batch_tokens, torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); CHECK_CUDA(mask); CHECK_CUDA(dest_idx); @@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h, torch::Tensor expert_grad, torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); CHECK_CUDA(mask); CHECK_CUDA(dest_idx); @@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h, torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); CHECK_INPUT(logits); 
CHECK_CUDA(mask); @@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h, dest_idx); } -std::vector moe_combine_backward(int s, int e, int c, int h, - torch::Tensor tokens_grad, - torch::Tensor expert_tokens, - torch::Tensor logits, - torch::Tensor mask, - torch::Tensor dest_idx) { +std::vector +moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); CHECK_INPUT(logits); CHECK_CUDA(mask); diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu index 0454377a2..ac7f8aba2 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -1,13 +1,12 @@ +#include "block_reduce.h" +#include #include #include #include -#include - -#include "block_reduce.h" - template __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { template __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { template __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, template __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, template __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, template __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, blockReduce(&thread_sum); - if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); + if (threadIdx.x == 0) + *weight_grad = static_cast(thread_sum); } template __device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, const T weight1, const T weight2, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, T *tks_row1, T *tks_row2, T *weight_grad1, T *weight_grad2, const T weight1, const T weight2, const int cols) { + assert(cols % pack_size == 0); const int bpack_size = block_size * pack_size; @@ -244,6 +251,7 @@ template __device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, const int cols, const int indicator1, const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) moe_dpch_two_fwd(src_row, dst_row1, dst_row2, cols); @@ -259,6 +267,7 @@ template __device__ void 
moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, const int cols, const int indicator1, const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) moe_dpch_two_bwd(src_row, dst_row1, dst_row2, cols); @@ -274,6 +283,7 @@ template __global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, int *mask1, int *mask2, int *dest1, int *dest2, const int h) { + int row = blockIdx.x; int indicator2 = mask2 == nullptr ? 0 : mask2[row]; moe_dpch_fwd_selector( @@ -285,6 +295,7 @@ template __global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, int *dest1, int *dest2, const int h) { + int row = blockIdx.x; int indicator2 = mask2 == nullptr ? 0 : mask2[row]; moe_dpch_bwd_selector( @@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, const int cols, const T weight1, const T weight2, const int indicator1, const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) moe_cb_two_fwd(src_row1, src_row2, dst_row, weight1, weight2, cols); @@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, T *wt_grad1, T *wt_grad2, const T weight1, const T weight2, const int indicator1, const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) moe_cb_two_bwd(src_row1, src_row2, dst_row, tks_row1, tks_row2, wt_grad1, @@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, T *logits, int *mask1, int *mask2, int *dest1, int *dest2, const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; int indicator2 = mask2 == nullptr ? 0 : mask2[row]; T *row_log = logits + (row * e); @@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, T *logits, T *logits_grad, int *mask1, int *mask2, int *dest1, int *dest2, const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; int indicator2 = mask2 == nullptr ? 
0 : mask2[row]; T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); @@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, template __global__ void cumsum_kernel(int *inputs, int *outputs, const int s, const int e) { + assert(s % pack_size == 0); constexpr int bpack_size = block_size * pack_size; int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; @@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s, } __syncthreads(); - if (tid == 0) temp[0] = temp[block_size]; + if (tid == 0) + temp[0] = temp[block_size]; __syncthreads(); if (idx + tps < s) { @@ -436,6 +453,7 @@ template void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, int *mask2, int *dest1, int *dest2, const int s, const int h) { + if (h < 256) moe_dpch_fwd_kernel <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); @@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, template void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, int *dest1, int *dest2, const int s, const int h) { + if (h < 256) moe_dpch_bwd_kernel <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); @@ -477,6 +496,7 @@ template void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, int *mask1, int *mask2, int *dest1, int *dest2, const int s, const int e, const int c, const int h) { + if (h < 256) moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, @@ -504,11 +524,12 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, T *logits_grad, int *mask1, int *mask2, int *dest1, int *dest2, const int s, const int e, const int c, const int h) { + if (h < 256) moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h); - else // if (h < 512) + else // if (h < 512) moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h); @@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, } void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); else if (s <= 512) @@ -537,26 +559,27 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { // API FUNCTIONS -------------------------------- -#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented yet for specific data type."); \ +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) 
\ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ } torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, torch::Tensor batch_tokens, torch::Tensor mask, torch::Tensor dest_idx) { + assert(h % 16 == 0); auto res = torch::zeros( {ec, h}, @@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, torch::Tensor expert_grad, torch::Tensor mask, torch::Tensor dest_idx) { + assert(h % 16 == 0); auto res = torch::zeros( {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); @@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, torch::Tensor dest_idx) { + assert(h % 16 == 0); assert(expert_tokens.dtype() == logits.dtype()); @@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, return res; } -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { +std::vector +moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, + torch::Tensor mask, torch::Tensor dest_idx) { + assert(h % 16 == 0); assert(tokens_grad.dtype() == expert_tokens.dtype()); assert(expert_tokens.dtype() == logits.dtype()); @@ -647,6 +673,7 @@ std::vector moe_combine_cuda_backward( } torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); assert(mask.dtype() == torch::kInt32); diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu index 49ab83e8f..8686e83f8 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -16,8 +16,7 @@ #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T *p) { +template __device__ __forceinline__ bool is_aligned(T *p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } @@ -29,12 +28,11 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; } -template -struct L2NormFunctor { - __device__ __forceinline__ void operator()( - int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, - float *output, float *output_per_tensor, bool per_tensor, - int max_chunks_per_tensor) { +template struct L2NormFunctor { + __device__ __forceinline__ void + operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; @@ -50,8 +48,8 @@ struct L2NormFunctor { __shared__ float s_vals[512]; - float vals[ILP]; // = {0}; // this probably works too but I want to be - // sure... + float + vals[ILP]; // = {0}; // this probably works too but I want to be sure... 
x_t r_x[ILP]; for (int i = 0; i < ILP; i++) { vals[i] = 0.f; @@ -86,14 +84,15 @@ struct L2NormFunctor { } float val = 0.f; - for (int i = 0; i < ILP; i++) val += vals[i]; + for (int i = 0; i < ILP; i++) + val += vals[i]; float final = reduce_block_into_lanes(s_vals, val); if (threadIdx.x == 0) { if (!isfinite(final)) *noop_gmem = - 1; // Blindly fire off a write. These will race but that's ok. + 1; // Blindly fire off a write. These will race but that's ok. output[blockIdx.x] += final; if (per_tensor) output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * @@ -105,12 +104,11 @@ struct L2NormFunctor { // Probably better to template, but since we are not likely to support other // norm -template -struct MaxNormFunctor { - __device__ __forceinline__ void operator()( - int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, - float *output, float *output_per_tensor, bool per_tensor, - int max_chunks_per_tensor) { +template struct MaxNormFunctor { + __device__ __forceinline__ void + operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl, + float *output, float *output_per_tensor, bool per_tensor, + int max_chunks_per_tensor) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; @@ -126,8 +124,8 @@ struct MaxNormFunctor { __shared__ float s_vals[512]; - float vals[ILP]; // = {0}; // this probably works too but I want to be - // sure... + float + vals[ILP]; // = {0}; // this probably works too but I want to be sure... x_t r_x[ILP]; for (int i = 0; i < ILP; i++) { vals[i] = 0.f; @@ -162,14 +160,15 @@ struct MaxNormFunctor { } float val = 0.f; - for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i])); + for (int i = 0; i < ILP; i++) + val = fmaxf(fabsf(val), fabsf(vals[i])); float final = reduce_block_into_lanes_max_op(s_vals, val); if (threadIdx.x == 0) { if (!isfinite(final)) *noop_gmem = - 1; // Blindly fire off a write. These will race but that's ok. + 1; // Blindly fire off a write. These will race but that's ok. 
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); if (per_tensor) output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * @@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret, if (blockIdx.x == 0) { float val = 0; - if (threadIdx.x < 320) val = output[threadIdx.x]; + if (threadIdx.x < 320) + val = output[threadIdx.x]; float final = reduce_block_into_lanes(vals, val); - if (threadIdx.x == 0) *ret = sqrt(final); + if (threadIdx.x == 0) + *ret = sqrt(final); } if (per_tensor) { @@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret, float final = reduce_block_into_lanes(vals, val); - if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final); + if (threadIdx.x == 0) + ret_per_tensor[blockIdx.x] = sqrt(final); } } @@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, if (blockIdx.x == 0) { float val = 0; - if (threadIdx.x < 320) val = output[threadIdx.x]; + if (threadIdx.x < 320) + val = output[threadIdx.x]; if (norm_type == 0) { float final = reduce_block_into_lanes_max_op(vals, val); - if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final; + if (threadIdx.x == 0) + *ret = alpha * (*ret) + beta * final; } else { float final = reduce_block_into_lanes(vals, val); - if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); + if (threadIdx.x == 0) + *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); } } @@ -255,10 +260,10 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, } } -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python) { +std::tuple +multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { bool per_tensor = per_tensor_python.has_value() ? 
per_tensor_python.value() : false; diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu index 54c422019..15ac20914 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -15,8 +15,7 @@ #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T *p) { +template __device__ __forceinline__ bool is_aligned(T *p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } @@ -29,25 +28,24 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, } typedef enum { - MOMENT_MODE_0 = 0, // L2 regularization mode - MOMENT_MODE_1 = 1 // Decoupled weight decay mode + MOMENT_MODE_0 = 0, // L2 regularization mode + MOMENT_MODE_1 = 1 // Decoupled weight decay mode } adamMode_t; -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); +std::tuple +multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python); using MATH_T = float; -template -struct LAMBStage1Functor { - __device__ __forceinline__ void operator()( - int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, - const float beta1, const float beta2, const float beta3, - const float beta1_correction, const float beta2_correction, - const float epsilon, adamMode_t mode, const float decay, - const float *global_grad_norm, const float max_global_grad_norm) { +template struct LAMBStage1Functor { + __device__ __forceinline__ void + operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, + const float beta1, const float beta2, const float beta3, + const float beta1_correction, const float beta2_correction, + const float epsilon, adamMode_t mode, const float decay, + const float *global_grad_norm, const float max_global_grad_norm) { // I'd like this kernel to propagate infs/nans. // if(*noop_gmem == 1) // return; @@ -91,7 +89,8 @@ struct LAMBStage1Functor { i_start += blockDim.x) { // load load_store(l_g, g, 0, i_start); - if (decay != 0) load_store(l_p, p, 0, i_start); + if (decay != 0) + load_store(l_p, p, 0, i_start); load_store(l_m, m, 0, i_start); load_store(l_v, v, 0, i_start); // unpack @@ -205,12 +204,12 @@ struct LAMBStage1Functor { // Step 2 reads in 'update' value and per-tensor param_norm and update_norm. // It computes new parameter value. -template -struct LAMBStage2Functor { - __device__ __forceinline__ void operator()( - int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl, - const float *per_tensor_param_norm, const float *per_tensor_update_norm, - const float learning_rate, const float decay, bool use_nvlamb) { +template struct LAMBStage2Functor { + __device__ __forceinline__ void + operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl, + const float *per_tensor_param_norm, + const float *per_tensor_update_norm, const float learning_rate, + const float decay, bool use_nvlamb) { // I'd like this kernel to propagate infs/nans. 
// if(*noop_gmem == 1) // return; @@ -311,7 +310,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, // Handle grad averaging mode float beta3 = 1.0f; - if (grad_averaging == 1) beta3 = 1 - beta1; + if (grad_averaging == 1) + beta3 = 1 - beta1; std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1); @@ -330,7 +330,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, LAMBStage1Functor(), beta1, beta2, - beta3, // 1-beta1 or 1 depends on averaging mode + beta3, // 1-beta1 or 1 depends on averaging mode bias_correction1, bias_correction2, epsilon, (adamMode_t)mode, weight_decay, global_grad_norm.DATA_PTR(), max_grad_norm);) diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu index 360485dcd..98161792e 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -15,8 +15,7 @@ #define BLOCK_SIZE 512 #define ILP 4 -template -__device__ __forceinline__ bool is_aligned(T *p) { +template __device__ __forceinline__ bool is_aligned(T *p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; } @@ -28,8 +27,7 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; } -template -struct ScaleFunctor { +template struct ScaleFunctor { __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl, @@ -78,7 +76,8 @@ struct ScaleFunctor { for (int ii = 0; ii < ILP; ii++) { r_in[ii] = 0; int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) r_in[ii] = in[i]; + if (i < n && i < chunk_size) + r_in[ii] = in[i]; } // note for clarification to future michael: // From a pure memory dependency perspective, there's likely no point @@ -94,13 +93,14 @@ struct ScaleFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) out[i] = r_out[ii]; + if (i < n && i < chunk_size) + out[i] = r_out[ii]; } } } if (!finite) *noop_gmem = - 1; // Blindly fire off a write. These will race but that's ok. + 1; // Blindly fire off a write. These will race but that's ok. 
} }; diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu index a077bc738..bc30e2722 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -1,15 +1,14 @@ -// modified from -// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu #include #include #include #include +#include "multi_tensor_apply.cuh" +#include "compat.h" + #include #include -#include "compat.h" -#include "multi_tensor_apply.cuh" - #define BLOCK_SIZE 512 #define ILP 4 @@ -29,53 +28,69 @@ * wd_after_momentum : apply weight decay _after_ momentum instead of before **/ template -struct SGDFunctor { - __device__ __forceinline__ void operator()( - int chunk_size, volatile int *noop_gmem, TensorListMetadata &tl, - float wd, float momentum, float dampening, float lr, bool nesterov, - bool first_run, bool wd_after_momentum, float scale) { - // Early exit if we don't need to do anything - if (*noop_gmem) return; +struct SGDFunctor +{ + __device__ __forceinline__ void operator()( + int chunk_size, + volatile int *noop_gmem, + TensorListMetadata &tl, + float wd, + float momentum, + float dampening, + float lr, + bool nesterov, + bool first_run, + bool wd_after_momentum, + float scale) + { + // Early exit if we don't need to do anything + if (*noop_gmem) + return; - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; - T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc]; - grad_in += chunk_idx * chunk_size; + T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc]; + grad_in += chunk_idx * chunk_size; - T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc]; - weight_in += chunk_idx * chunk_size; + T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc]; + weight_in += chunk_idx * chunk_size; - T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc]; - mom_in += chunk_idx * chunk_size; + T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc]; + mom_in += chunk_idx * chunk_size; - at::Half *model_weights_out = nullptr; - if (N == 4) { - model_weights_out = (at::Half *)tl.addresses[3][tensor_loc]; - model_weights_out += chunk_idx * chunk_size; - } - - n -= chunk_idx * chunk_size; - - // Non-divergent exit condition for the __syncthreads - float incoming_grads[ILP]; - float incoming_weights[ILP]; - float incoming_moms[ILP]; - for (int i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * ILP) { -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_grads[ii] = 0; - incoming_weights[ii] = 0; - incoming_moms[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) { - incoming_grads[ii] = static_cast(grad_in[i]) * scale; - incoming_weights[ii] = static_cast(weight_in[i]); - incoming_moms[ii] = static_cast(mom_in[i]); + at::Half *model_weights_out = nullptr; + if (N == 4) + { + model_weights_out = (at::Half *)tl.addresses[3][tensor_loc]; + model_weights_out += chunk_idx * chunk_size; } - } + + n -= chunk_idx * chunk_size; + + // Non-divergent exit condition for the __syncthreads + float incoming_grads[ILP]; + float incoming_weights[ILP]; + float 
incoming_moms[ILP]; + for (int i_start = 0; + i_start < n && i_start < chunk_size; + i_start += blockDim.x * ILP) + { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) + { + incoming_grads[ii] = 0; + incoming_weights[ii] = 0; + incoming_moms[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + { + incoming_grads[ii] = static_cast(grad_in[i]) * scale; + incoming_weights[ii] = static_cast(weight_in[i]); + incoming_moms[ii] = static_cast(mom_in[i]); + } + } // note for clarification to future michael: // From a pure memory dependency perspective, there's likely no point unrolling @@ -83,128 +98,185 @@ struct SGDFunctor { // Put another way, the STGs are dependent on the LDGs, but not on each other. // There is still compute ILP benefit from unrolling the loop though. #pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) { - // apply weight decay before momentum if necessary - if (wd != 0.f && !wd_after_momentum) - incoming_grads[ii] += wd * incoming_weights[ii]; + for (int ii = 0; ii < ILP; ii++) + { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + { + // apply weight decay before momentum if necessary + if (wd != 0.f && !wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; - if (momentum != 0.f) { - if (!first_run) - incoming_moms[ii] = incoming_moms[ii] * momentum + - (1.f - dampening) * incoming_grads[ii]; - else // initialize momentums to current incoming grads - incoming_moms[ii] = incoming_grads[ii]; + if (momentum != 0.f) + { + if (!first_run) + incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii]; + else // initialize momentums to current incoming grads + incoming_moms[ii] = incoming_grads[ii]; - if (nesterov) - incoming_grads[ii] += momentum * incoming_moms[ii]; - else - incoming_grads[ii] = incoming_moms[ii]; - } + if (nesterov) + incoming_grads[ii] += momentum * incoming_moms[ii]; + else + incoming_grads[ii] = incoming_moms[ii]; + } - // Apply WD after momentum if desired - if (wd != 0.f && wd_after_momentum) - incoming_grads[ii] += wd * incoming_weights[ii]; + // Apply WD after momentum if desired + if (wd != 0.f && wd_after_momentum) + incoming_grads[ii] += wd * incoming_weights[ii]; - // adjust the weight and write out - weight_in[i] += (-lr * incoming_grads[ii]); + // adjust the weight and write out + weight_in[i] += (-lr * incoming_grads[ii]); - // if necessary, write out an fp16 copy of the weights - if (N == 4) - model_weights_out[i] = static_cast(weight_in[i]); + // if necessary, write out an fp16 copy of the weights + if (N == 4) + model_weights_out[i] = static_cast(weight_in[i]); - // also write out the new momentum - if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; + // also write out the new momentum + if (momentum != 0.f) + mom_in[i] = incoming_moms[ii]; + } + } } - } } - } }; -void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - float wd, float momentum, float dampening, float lr, - bool nesterov, bool first_run, - bool wd_after_momentum, float scale) { - auto num_tensors = tensor_lists.size(); - auto grad_type = tensor_lists[0][0].scalar_type(); - auto weight_type = tensor_lists[1][0].scalar_type(); +void multi_tensor_sgd_cuda( + int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + float wd, + float momentum, + float dampening, + float lr, + bool nesterov, + bool first_run, + bool 
wd_after_momentum, + float scale) +{ + auto num_tensors = tensor_lists.size(); + auto grad_type = tensor_lists[0][0].scalar_type(); + auto weight_type = tensor_lists[1][0].scalar_type(); - if (num_tensors == 4) - for (int i = 0; i < tensor_lists[3].size(); i++) - TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, - "Additional output tensors should always be fp16."); + if (num_tensors == 4) + for (int i = 0; i < tensor_lists[3].size(); i++) + TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, + "Additional output tensors should always be fp16."); - TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), - "expected noop flag to be on the same device as tensors"); + TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors"); - // We have 3 possibilities to handle here, in terms of - // grad_type, param_type, momentum_type, requires_fp16_copy - // 1. fp16, fp16, fp16, No - // 2. fp32, fp32, fp32, No - // 3. fp16, fp32, fp32, Yes - // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case - // It's easier to hardcode these possibilities than to use - // switches etc. to handle the cross-product of cases where - // we don't want the majority of them. + // We have 3 possibilities to handle here, in terms of + // grad_type, param_type, momentum_type, requires_fp16_copy + // 1. fp16, fp16, fp16, No + // 2. fp32, fp32, fp32, No + // 3. fp16, fp32, fp32, Yes + // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case + // It's easier to hardcode these possibilities than to use + // switches etc. to handle the cross-product of cases where + // we don't want the majority of them. - // Case 1. fp16, fp16, fp16, No - if (grad_type == at::ScalarType::Half && - weight_type == at::ScalarType::Half && num_tensors == 3) { - multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<3, at::Half, at::Half>(), wd, momentum, - dampening, lr, nesterov, first_run, wd_after_momentum, - scale); - } - // Case 2. fp16, fp32, fp32, No - // else if (grad_type == at::ScalarType::Half && - // weight_type == at::ScalarType::Float && - // num_tensors == 3) { - // multi_tensor_apply<3>( - // BLOCK_SIZE, - // chunk_size, - // noop_flag, - // tensor_lists, - // SGDFunctor<3, at::Half, float>(), - // wd, - // momentum, - // dampening, - // lr, - // nesterov, - // first_run, - // wd_after_momentum); - // } - // Case 2. fp32, fp32, fp32, No - else if (grad_type == at::ScalarType::Float && - weight_type == at::ScalarType::Float && num_tensors == 3) { - multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<3, float, float>(), wd, momentum, - dampening, lr, nesterov, first_run, wd_after_momentum, - scale); - } - // Case 3. fp16, fp32, fp32, Yes - else if (grad_type == at::ScalarType::Half && - weight_type == at::ScalarType::Float && num_tensors == 4) { - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<4, at::Half, float>(), wd, momentum, - dampening, lr, nesterov, first_run, wd_after_momentum, - scale); - } - // Case 4. 
fp32, fp32, fp32, Yes - else if (grad_type == at::ScalarType::Float && - weight_type == at::ScalarType::Float && num_tensors == 4) { - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<4, float, float>(), wd, momentum, - dampening, lr, nesterov, first_run, wd_after_momentum, - scale); - } else { - AT_ERROR( - "multi_tensor_sgd only supports some combinations of gradient & weight " - "types. Given: ", - "gradient: ", grad_type, ", weight: ", weight_type, - ", num_lists: ", num_tensors); - } + // Case 1. fp16, fp16, fp16, No + if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Half && + num_tensors == 3) + { + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + SGDFunctor<3, at::Half, at::Half>(), + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale); + } + // Case 2. fp16, fp32, fp32, No + // else if (grad_type == at::ScalarType::Half && + // weight_type == at::ScalarType::Float && + // num_tensors == 3) { + // multi_tensor_apply<3>( + // BLOCK_SIZE, + // chunk_size, + // noop_flag, + // tensor_lists, + // SGDFunctor<3, at::Half, float>(), + // wd, + // momentum, + // dampening, + // lr, + // nesterov, + // first_run, + // wd_after_momentum); + // } + // Case 2. fp32, fp32, fp32, No + else if (grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && + num_tensors == 3) + { + multi_tensor_apply<3>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + SGDFunctor<3, float, float>(), + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale); + } + // Case 3. fp16, fp32, fp32, Yes + else if (grad_type == at::ScalarType::Half && + weight_type == at::ScalarType::Float && + num_tensors == 4) + { + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + SGDFunctor<4, at::Half, float>(), + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale); + } + // Case 4. fp32, fp32, fp32, Yes + else if (grad_type == at::ScalarType::Float && + weight_type == at::ScalarType::Float && + num_tensors == 4) + { + multi_tensor_apply<4>( + BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + SGDFunctor<4, float, float>(), + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale); + } + else + { + AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. 
Given: ", + "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors); + } - AT_CUDA_CHECK(cudaGetLastError()); + AT_CUDA_CHECK(cudaGetLastError()); } \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp index b02556f79..63bf633f5 100644 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp @@ -10,9 +10,8 @@ #include "kernels.h" template -MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_size, - int num_heads, +MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len, + int hidden_size, int num_heads, float attn_prob_dropout_ratio, float hidden_output_dropout_ratio, bool pre_or_postLayerNorm) @@ -23,22 +22,18 @@ MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, _heads(num_heads), _training(true), _pre_or_postLayerNorm(pre_or_postLayerNorm), - _qkv_linear( - typename FeedForward::Config(3 * hidden_size, hidden_size)), - _attn_out_linear( - typename FeedForward::Config(hidden_size, hidden_size)), - _attn_ln(typename Normalize_Layer::Config(hidden_size, false), - _max_batch_tokens), + _qkv_linear(typename FeedForward::Config(3 * hidden_size, hidden_size)), + _attn_out_linear(typename FeedForward::Config(hidden_size, hidden_size)), + _attn_ln(typename Normalize_Layer::Config(hidden_size, false), _max_batch_tokens), _softmax(typename Softmax::Config(num_heads)), _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), _max_batch_tokens * _heads * _max_seq_len), _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), _max_batch_tokens * _hidden_size), - _attn_scores(typename StridedBatchGemm::Config( - (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, - CUBLAS_OP_N)), - _attn_context(typename StridedBatchGemm::Config( - T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { + _attn_scores(typename StridedBatchGemm::Config((T(1.0) / T(sqrt(_hidden_size / _heads))), + T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)), + _attn_context( + typename StridedBatchGemm::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { assert(_hidden_size % _heads == 0); } @@ -48,52 +43,43 @@ MultiHeadAttention::~MultiHeadAttention() { } template -void MultiHeadAttention::attn_layer_fw(const T *input_ptr, - const T *input_mask_ptr, +void MultiHeadAttention::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer) { T *q_tf_ptr = _qkv_ptr; T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; if (_pre_or_postLayerNorm) { - _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); + _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, + _stream); } - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? 
_gemmQKV_inp_ptr : input_ptr; _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, - _cublasHandle); + _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle); - launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, - _batch_size, _seq_len, 3, _heads / pg_size, - _hidden_size / _heads, _stream); + launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3, + _heads / pg_size, _hidden_size / _heads, _stream); // attention scores, q*k - _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle); + _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle); // Softmax + Mask _softmax.reset_size(_heads / pg_size); - _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, - _seq_len, _stream, true); + _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true); // attn prob dropout. - _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - // attention context, score * v - _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, - _cublasHandle); - - // [b, nh, s, ad] -> [b, s, nh, ad] - launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, - _hidden_size / pg_size, _heads / pg_size, 1, + _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len, _stream); + // attention context, score * v + _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle); + + // [b, nh, s, ad] -> [b, s, nh, ad] + launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size, + _heads / pg_size, 1, _stream); + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, - output_ptr, _cublasHandle); + _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle); // allreduce if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { @@ -102,27 +88,24 @@ void MultiHeadAttention::attn_layer_fw(const T *input_ptr, if (typeid(T) != typeid(float)) { data_type = torch::kHalf; } - auto output_tensor = torch::from_blob( - output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); + auto output_tensor = + torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); std::vector allreduce_tensors = {output_tensor}; auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); work->wait(); } - _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, - _attn_ob_ptr, _batch_tokens, _hidden_size, - _stream); + _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr, + _batch_tokens, _hidden_size, _stream); if (!_pre_or_postLayerNorm) { // in-place ln since ln-input will not be used in post-ln mode - _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); + _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream); } } template -void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, - T *out_ptr) { +void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, T 
*out_ptr) { _stream = Context::Instance().get_stream(); _cublasHandle = Context::Instance().get_cublashandle(); T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim @@ -131,11 +114,8 @@ void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, } template -void MultiHeadAttention::attn_layer_bw(const T *input_ptr, - const T *input_mask_ptr, - const T *output_ptr, - const T *grad_output_ptr, - T *grad_input_ptr, T *buffer) { +void MultiHeadAttention::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr, + const T *grad_output_ptr, T *grad_input_ptr, T *buffer) { cudaStream_t streams[2] = {_stream, _stream}; const T *q_tf_ptr = _qkv_ptr; @@ -157,57 +137,45 @@ void MultiHeadAttention::attn_layer_bw(const T *input_ptr, // batch_size * head_num * seq_len * seq_len); if (_pre_or_postLayerNorm) { - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_output_ptr, _batch_tokens, - _hidden_size, _stream); + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr, + _batch_tokens, _hidden_size, _stream); } else { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, - grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, - _attn_nb_ptr, _batch_tokens, streams); - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_residual_ptr, _batch_tokens, - _hidden_size, _stream); + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr, + nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr, + _batch_tokens, _hidden_size, _stream); } // bw of output project _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, - _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - false); - launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - _stream); + _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr, + _grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream, + grad_input_buf_ptr, nullptr, false); + launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len, + _hidden_size / pg_size, _heads / pg_size, _stream); // bw of score * v - _attn_context.Backward( - _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, - grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); + _attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, + grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); - _attn_prob_dropout.d_dropout(grad_softmax_ptr, - _batch_heads * _seq_len * _seq_len, _stream); + _attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream); _softmax.reset_size(_heads / pg_size); - _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, - _seq_len, _stream); + _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream); // bw of q * k - _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, - grad_qkv_5d_ptr); + _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle, + grad_qkv_5d_ptr + _batch_dim / 
pg_size, grad_qkv_5d_ptr); // [3, b, nh, s, ad] -> [b, s, 3, h] - launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - 3, _stream); + launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len, + _hidden_size / pg_size, _heads / pg_size, 3, _stream); - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, - _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - true); + _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr, + _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream, + grad_input_buf_ptr, nullptr, true); // allreduce if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { @@ -217,8 +185,7 @@ void MultiHeadAttention::attn_layer_bw(const T *input_ptr, data_type = torch::kHalf; } auto grad_input_tensor = - torch::from_blob(grad_input_buf_ptr, - {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, torch::TensorOptions(torch::kCUDA).dtype(data_type)); std::vector allreduce_tensors = {grad_input_tensor}; auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); @@ -226,21 +193,19 @@ void MultiHeadAttention::attn_layer_bw(const T *input_ptr, } if (_pre_or_postLayerNorm) { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, - grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, - _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr, + grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, + streams); } else { // FIXME later - launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, - _batch_size, _seq_len, _hidden_size, _stream); + launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size, + _seq_len, _hidden_size, _stream); } } template -void MultiHeadAttention::Backward(const T *grad_output_ptr, - const T *input_ptr, const T *output_ptr, - const T *input_mask_ptr, - T *grad_input_ptr) { +void MultiHeadAttention::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr, + const T *input_mask_ptr, T *grad_input_ptr) { _stream = Context::Instance().get_stream(); _cublasHandle = Context::Instance().get_cublashandle(); T *buffer = _shared_mem_ptr; @@ -250,8 +215,7 @@ void MultiHeadAttention::Backward(const T *grad_output_ptr, 4 * _batch_dim + max(3 * _batch_dim, _batch_size * _head_num * _seq_len * _seq_len); */ - attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, - grad_input_ptr, buffer); + attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer); } template @@ -269,8 +233,7 @@ template class MultiHeadAttention<__half>; // x is torch::Tensor #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) @@ -278,17 +241,15 @@ template class 
MultiHeadAttention<__half>; static std::unordered_map> s_multihead_attention; template -int create_multihead_attention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_dim, int num_heads, - float attn_prob_dropout_ratio, - float hidden_dropout_ratio, - bool pre_or_postLayerNorm, +int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim, + int num_heads, float attn_prob_dropout_ratio, + float hidden_dropout_ratio, bool pre_or_postLayerNorm, c10::intrusive_ptr pg_) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); Context::Instance().set_stream(stream); auto layer = std::make_shared>( - layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, - attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); + layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio, + hidden_dropout_ratio, pre_or_postLayerNorm); layer->SetPG(pg_); @@ -300,12 +261,15 @@ int create_multihead_attention(int layer_id, int max_batch_tokens, } template -std::vector multihead_attention_fw( - int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, - const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, - const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, - const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, - bool training_mode, bool prelayernorm) { +std::vector multihead_attention_fw(int layer_id, const torch::Tensor &input, + const torch::Tensor &input_mask, + const torch::Tensor &in_proj_weight, + const torch::Tensor &in_proj_bias, + const torch::Tensor &out_proj_weight, + const torch::Tensor &out_proj_bias, + const torch::Tensor &norm_weight, + const torch::Tensor &norm_bias, + bool training_mode, bool prelayernorm) { CHECK_INPUT(input); CHECK_INPUT(input_mask); @@ -316,8 +280,7 @@ std::vector multihead_attention_fw( T *out_ptr = (T *)output.data_ptr(); std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); + std::static_pointer_cast>(s_multihead_attention[layer_id]); layer->set_cur_batch_shape(input.size(0), input.size(1)); layer->SetTrainingMode(training_mode); @@ -334,13 +297,17 @@ std::vector multihead_attention_fw( } template -std::vector multihead_attention_bw( - int layer_id, const torch::Tensor &grad_dec_output, - const torch::Tensor &output, const torch::Tensor &input, - const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, - const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, - const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, - const torch::Tensor &norm_bias) { +std::vector multihead_attention_bw(int layer_id, + const torch::Tensor &grad_dec_output, + const torch::Tensor &output, + const torch::Tensor &input, + const torch::Tensor &input_mask, + const torch::Tensor &in_proj_weight, + const torch::Tensor &in_proj_bias, + const torch::Tensor &out_proj_weight, + const torch::Tensor &out_proj_bias, + const torch::Tensor &norm_weight, + const torch::Tensor &norm_bias) { auto g_output = grad_dec_output.contiguous(); CHECK_INPUT(g_output); CHECK_INPUT(output); @@ -365,8 +332,7 @@ std::vector multihead_attention_bw( T *grad_input_ptr = (T *)grad_input.data_ptr(); std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); + std::static_pointer_cast>(s_multihead_attention[layer_id]); layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); layer->_grad_attn_qkvw_ptr = (T 
*)grad_in_proj_weight.data_ptr(); @@ -376,12 +342,10 @@ std::vector multihead_attention_bw( layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); - layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, - grad_input_ptr); + layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr); - return {grad_input, grad_in_proj_weight, grad_in_proj_bias, - grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, - grad_norm_bias}; + return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, + grad_out_proj_bias, grad_norm_weight, grad_norm_bias}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h index 70b3419d8..1dd84773a 100644 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h @@ -19,25 +19,21 @@ template class MultiHeadAttention { public: - MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, - int hidden_size, int num_heads, float attn_dropout_ratio, - float hidden_output_dropout_ratio, + MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size, + int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio, bool pre_or_postLayerNorm); virtual ~MultiHeadAttention(); void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); - void Backward(const T *grad_output_ptr, const T *input_ptr, - const T *output_ptr, const T *input_mask_ptr, - T *grad_input_ptr); + void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr, + const T *input_mask_ptr, T *grad_input_ptr); - void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, - T *buffer); + void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer); - void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, - const T *output_ptr, const T *grad_output_ptr, - T *grad_input_attn_layer_bwptr, T *buffer); + void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr, + const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer); void set_cur_batch_shape(int batch_size, int seq_len) { _batch_size = batch_size; @@ -87,17 +83,14 @@ class MultiHeadAttention { } _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); - _soft_out_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _ctx_bufB_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _soft_out_ptr = cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _ctx_bufB_ptr = cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); // buffer size needed by attn bw - size_t smem_size = - 4 * _max_batch_tokens * _hidden_size / pg_size + - std::max(3 * _max_batch_tokens * _hidden_size / pg_size, - _max_batch_tokens * _heads / pg_size * _max_seq_len); + size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size + + std::max(3 * _max_batch_tokens * _hidden_size / pg_size, + _max_batch_tokens * _heads / pg_size * _max_seq_len); if (!_shared_mem_ptr) { cuda_free(_shared_mem_ptr); diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu index 
41781ebc7..d2370e9f3 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu @@ -2,13 +2,12 @@ * with minor changes. */ #include -#include #include +#include #include #include -#include +#include #include - #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -16,15 +15,17 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); } -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, - float scale_factor) { - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, - // seq_len] + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int pad_batches = mask.size(0); const int attn_heads = input.size(1); @@ -37,10 +38,10 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty( - {batches, attn_heads, query_seq_len, key_seq_len}, act_options); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -48,23 +49,31 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), "dispatch_scaled_masked_softmax_forward", + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", dispatch_scaled_masked_softmax_forward( reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), scale_factor, - query_seq_len, key_seq_len, batches, attn_heads, pad_batches);); + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); - // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, - // seq_len] + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = output_grads.size(0); const int attn_heads = output_grads.size(1); const int query_seq_len = output_grads.size(2); @@ -72,18 +81,24 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, void* output_grads_ptr = static_cast(output_grads.data_ptr()); - // Softmax Grad + //Softmax Grad DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), 
"dispatch_scaled_masked_softmax_backward", + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, query_seq_len, key_seq_len, batches, attn_heads);); - - // backward pass is completely in-place + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place return output_grads; } -} // namespace scaled_masked_softmax -} // namespace fused_softmax -} // namespace multihead_attn +} +} +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp index cbbc37064..590ea7b3f 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp @@ -3,52 +3,57 @@ #include #include - #include namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return fwd_cuda(input, scale_factor); } -torch::Tensor bwd(torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, float scale_factor) { +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", + m.def("forward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", + "Self Multihead 
Attention scaled, time masked softmax -- Forward."); + m.def("backward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + "Self Multihead Attention scaled, time masked softmax -- Backward."); } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu index 62c56e6f7..9dbb63476 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu @@ -2,13 +2,12 @@ * with minor changes. */ #include -#include #include +#include #include #include -#include +#include #include - #include "scaled_upper_triang_masked_softmax.h" #include "type_shim.h" @@ -16,15 +15,18 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = input.size(0); const int seq_len = input.size(1); TORCH_INTERNAL_ASSERT(seq_len <= 2048); - // Output + // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = + torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); // Softmax Intermediate Result Ptr @@ -34,42 +36,50 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { DISPATCH_HALF_AND_BFLOAT( input.scalar_type(), "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), scale_factor, seq_len, - seq_len, attn_batches);); + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); return softmax_results; } + -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); - // output grads is a 3d tensor with dimensions [attn_batches, seq_len, - // seq_len] + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = output_grads.size(0); const int seq_len = output_grads.size(1); TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); void* output_grads_ptr = static_cast(output_grads.data_ptr()); - // Softmax Grad + //Softmax Grad DISPATCH_HALF_AND_BFLOAT( output_grads_.scalar_type(), "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, seq_len, seq_len, attn_batches);); - - // backward pass is completely in-place + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + 
attn_batches); + ); + + //backward pass is completely in-place return output_grads; } -} // namespace scaled_upper_triang_masked_softmax -} // namespace fused_softmax -} // namespace multihead_attn +} +} +} diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 38e95e2f8..af66eb827 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -24,8 +24,8 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_, - ctx.eps) + output, mean, invvar = colossal_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @@ -72,7 +72,8 @@ class MixedFusedLayerNorm(torch.nn.Module): def forward(self, input): - return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) + return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, + self.normalized_shape, self.eps) def __repr__(self): return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index cb36da8a1..786b922c6 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -28,7 +28,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') scale_t = torch.tensor([scale]) - softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + softmax_results = colossal_scaled_upper_triang_masked_softmax.forward( + inputs, scale_t[0] + ) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @@ -41,7 +43,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') softmax_results, scale_t = ctx.saved_tensors - input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + input_grads = colossal_scaled_upper_triang_masked_softmax.backward( + output_grads, softmax_results, scale_t[0] + ) return input_grads, None @@ -77,7 +81,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function): softmax_results, scale_t = ctx.saved_tensors - input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + input_grads = colossal_scaled_masked_softmax.backward( + output_grads, softmax_results, scale_t[0] + ) return input_grads, None, None @@ -108,8 +114,9 @@ class FusedScaleMaskSoftmax(nn.Module): super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 - and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." 
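A reading aid for the multi_tensor_sgd changes reflowed earlier in this patch: the SGDFunctor packs weight decay (applied before or after momentum), dampening, Nesterov momentum and an optional fp16 weight copy into one fused loop. The per-tensor PyTorch sketch below is an assumption-laden paraphrase of that update, not code from this diff; the function name is invented, all tensors are assumed fp32, and fp16_copy stands in for the extra half-precision output list written when four tensor lists are passed.

import torch

@torch.no_grad()
def sgd_functor_reference(weight, grad, momentum_buf, *,
                          wd, momentum, dampening, lr, nesterov,
                          first_run, wd_after_momentum,
                          scale=1.0, fp16_copy=None):
    # per-tensor reference of the fused SGD update (fp32 tensors assumed)
    g = grad * scale                                  # gradients are rescaled on load
    if wd != 0.0 and not wd_after_momentum:           # weight decay before momentum
        g = g + wd * weight
    if momentum != 0.0:
        if first_run:
            momentum_buf.copy_(g)                     # initialise buffer with the gradient
        else:
            momentum_buf.mul_(momentum).add_(g, alpha=1.0 - dampening)
        if nesterov:
            g = g + momentum * momentum_buf
        else:
            g = momentum_buf.clone()
    if wd != 0.0 and wd_after_momentum:               # weight decay after momentum
        g = g + wd * weight
    weight.add_(g, alpha=-lr)                         # in-place parameter update
    if fp16_copy is not None:                         # optional fp16 model-weight copy
        fp16_copy.copy_(weight)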
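The scaled_masked_softmax and scaled_upper_triang_masked_softmax kernels touched above, together with the autograd wrappers in this file, are easier to follow against an unfused reference. The sketch below is only a reading aid under stated assumptions: the shapes follow the comments in the .cu files, and the convention that True marks a masked position is assumed rather than taken from this diff.

import torch

def scaled_masked_softmax_reference(inputs, mask, scale):
    # inputs: [batches, attn_heads, query_seq_len, key_seq_len]
    # mask:   broadcastable bool tensor, True = key position to suppress (assumed)
    scores = inputs.float() * scale                    # scale first, softmax in fp32
    scores = scores.masked_fill(mask, float('-inf'))
    return torch.softmax(scores, dim=-1).to(inputs.dtype)

def scaled_upper_triang_masked_softmax_reference(inputs, scale):
    # causal variant; inputs: [attn_batches, seq_len, seq_len]
    seq_len = inputs.size(-1)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                   device=inputs.device), diagonal=1)
    scores = inputs.float() * scale
    scores = scores.masked_fill(future, float('-inf'))
    return torch.softmax(scores, dim=-1).to(inputs.dtype)

The fused kernels run the backward pass in place on output_grads, which is why the wrappers here save softmax_results (and the scale) for backward rather than the inputs.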
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion @@ -117,7 +124,9 @@ class FusedScaleMaskSoftmax(nn.Module): self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale - assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" + assert ( + self.scale is None or softmax_in_fp32 + ), "softmax should be in fp32 when scaled" def forward(self, input, mask): # [b, np, sq, sk] @@ -131,13 +140,14 @@ class FusedScaleMaskSoftmax(nn.Module): def is_kernel_available(self, mask, b, np, sq, sk): attn_batches = b * np - if (self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and mask is not None # mask tensor must not be None - and 16 < sk <= 2048 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): if 0 <= sk <= 2048: batch_per_block = self.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py index e6da70c40..f7a425dd5 100644 --- a/colossalai/kernel/jit/bias_gelu.py +++ b/colossalai/kernel/jit/bias_gelu.py @@ -1,5 +1,6 @@ import torch + ###### BIAS GELU FUSION/ NO AUTOGRAD ################ # 1/sqrt(2*pi)-> 0.3989423 # 1/sqrt(2) -> 0.70710678 @@ -8,12 +9,10 @@ import torch # actual gelu is: # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) - @torch.jit.script def bias_gelu(bias, y): x = bias + y - return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) - + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) # gradient of tanh approximation of gelu # gradient of actual gelu is: @@ -24,11 +23,9 @@ def bias_gelu_back(g, bias, y): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff * g - + return ff*g class GeLUFunction(torch.autograd.Function): - @staticmethod # bias is an optional argument def forward(ctx, input, bias): @@ -41,5 +38,4 @@ class GeLUFunction(torch.autograd.Function): tmp = bias_gelu_back(grad_output, bias, input) return tmp, tmp - -bias_gelu_impl = GeLUFunction.apply +bias_gelu_impl = GeLUFunction.apply \ No newline at end of file diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index 5fc5c63e5..cec7cb8f7 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -182,7 +182,7 @@ class Linear2D(ParallelLayer): def forward(self, x: Tensor) -> Tensor: # input: [m/q, n/q, k/q] # output: [m/q, n/q, h/q] - out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + out_shape = x.shape[:-1] + (self.hidden_size_per_partition, ) output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, @@ -337,16 +337,16 @@ class LayerNorm2D(ParallelLayer): def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): - E_x = 
torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) Var_x /= self.normalized_shape - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) @@ -569,7 +569,7 @@ class PatchEmbedding2D(ParallelLayer): output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) @@ -1012,7 +1012,7 @@ class Classifier2D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - out_shape = input_.shape[:-1] + (self.num_classes,) + out_shape = input_.shape[:-1] + (self.num_classes, ) return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, @@ -1186,7 +1186,7 @@ class VocabParallelClassifier2D(ParallelLayer): def forward(self, x: Tensor) -> Tensor: # input: [m/q, n/q, k/q] # output: [m/q, n/q, h/q] - out_shape = x.shape[:-1] + (self.output_size_per_partition,) + out_shape = x.shape[:-1] + (self.output_size_per_partition, ) output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py index f26efcc61..d89150642 100644 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -189,7 +189,7 @@ class Linear2p5D(ParallelLayer): def forward(self, x: Tensor) -> Tensor: # input: [m/dq, n/q, k/q] # output: [m/dq, n/q, h/q] - out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + out_shape = x.shape[:-1] + (self.hidden_size_per_partition, ) output = Matmul_AB_2p5D.apply( x, @@ -254,7 +254,7 @@ class LayerNorm2p5D(ParallelLayer): self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() # partitioning dimension - self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * + self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters factory_kwargs = {'device': get_current_device(), 'dtype': dtype} @@ -357,16 +357,16 @@ class LayerNorm2p5D(ParallelLayer): def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, 
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) Var_x /= self.normalized_shape - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) @@ -589,7 +589,7 @@ class PatchEmbedding2p5D(ParallelLayer): output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) @@ -1038,7 +1038,7 @@ class Classifier2p5D(ParallelLayer): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - out_shape = input_.shape[:-1] + (self.num_classes,) + out_shape = input_.shape[:-1] + (self.num_classes, ) return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, @@ -1172,7 +1172,7 @@ class VocabParallelClassifier2p5D(ParallelLayer): def forward(self, x: Tensor) -> Tensor: # input: [m/dq, n/q, k/q] # output: [m/dq, n/q, h/q] - out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) + out_shape = x.shape[:-1] + (self.hidden_size_per_partition, ) output = Matmul_ABT_2p5D.apply( x, diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 33f358241..654d5d07f 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer): self.weight = Parameter( torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) if bias: - self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, + device=get_current_device(), dtype=dtype)) else: self.bias = None self.variance_epsilon = eps @@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer): input_ = split_tensor_3d(input_, 0, self.input_parallel_mode) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py index f2297304f..a112f2d95 100644 --- a/colossalai/nn/layer/utils/common.py +++ b/colossalai/nn/layer/utils/common.py @@ -13,8 +13,7 @@ from torch import Tensor, nn class CheckpointModule(nn.Module): - - def __init__(self, checkpoint: bool = True, offload: bool = False): + def __init__(self, checkpoint: bool = True, offload : bool = False): super().__init__() self.checkpoint = checkpoint self._use_checkpoint = checkpoint @@ -79,7 +78,6 @@ def get_tensor_parallel_mode(): def _ntuple(n): - def parse(x): if isinstance(x, collections.abc.Iterable): return x