diff --git a/extensions/csrc/common/dev_info_mgr.h b/extensions/csrc/common/dev_info_mgr.h
new file mode 100644
index 000000000..7570666ad
--- /dev/null
+++ b/extensions/csrc/common/dev_info_mgr.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <memory>
+
+#include "common/nvgpu_dev_info.h"
+#include "target.h"
+
+namespace colossalAI {
+namespace common {
+
+template <typename Ret>
+class DevInfoMgr final {
+ public:
+  static std::unique_ptr<Ret> GetDevInfo(int device_num) {
+    return std::make_unique<Ret>(device_num);
+  }
+};
+
+}  // namespace common
+}  // namespace colossalAI
diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/common/micros.h
similarity index 87%
rename from extensions/csrc/cuda/type_shim.h
rename to extensions/csrc/common/micros.h
index 7be3fab1b..c2241029f 100644
--- a/extensions/csrc/cuda/type_shim.h
+++ b/extensions/csrc/common/micros.h
@@ -9,7 +9,15 @@
 #include <ATen/ATen.h>
 
-#include "compat.h"
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
 
 #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
   switch (TYPE) {                                 \
@@ -214,90 +222,3 @@
   AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
            "'");                                                             \
   }
-
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes(
-    T *x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = x[tid] + x[tid + i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = x[tid] + x[tid + 32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(
-    T *x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final =
-          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
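The `DISPATCH_HALF_AND_BFLOAT` macro kept in the renamed `micros.h` drives dtype dispatch for the CUDA kernels touched below. A minimal sketch of the intended call pattern, assuming the macro binds the matched dtype to a local `scalar_t` alias as apex-style dispatch macros do; the wrapper name and launch shape are illustrative, not part of this PR:

```cpp
#include <cuda_runtime.h>

#include <ATen/ATen.h>

#include "common/micros.h"

// Hypothetical host-side wrapper for the act_and_mul_kernel added below.
void act_and_mul(const at::Tensor& ins, at::Tensor& outs) {
  const int64_t numel = outs.numel();
  const dim3 block(256);
  const dim3 grid(static_cast<unsigned int>((numel + block.x - 1) / block.x));
  // The variadic macro forwards everything after NAME verbatim, so the kernel
  // launch (commas included) can be passed through as __VA_ARGS__.
  DISPATCH_HALF_AND_BFLOAT(
      ins.scalar_type(), "act_and_mul",
      act_and_mul_kernel<scalar_t><<<grid, block>>>(
          ins.data_ptr<scalar_t>(), outs.data_ptr<scalar_t>(), numel););
}
```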
diff --git a/extensions/csrc/cuda/include/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h
similarity index 75%
rename from extensions/csrc/cuda/include/mp_type_traits.h
rename to extensions/csrc/common/mp_type_traits.h
index 6b3ae9c1b..8ede2d448 100644
--- a/extensions/csrc/cuda/include/mp_type_traits.h
+++ b/extensions/csrc/common/mp_type_traits.h
@@ -2,10 +2,10 @@
 
 #include <ATen/ATen.h>
 
-#include "../type_shim.h"
+#include "micros.h"
 
-namespace infer {
-namespace dtype {
+namespace colossalAI {
+namespace common {
 
 template <typename T>
 class MPTypeTrait {
@@ -31,5 +31,5 @@ class MPTypeTrait {
   using Type = float;
 };
 
-}  // namespace dtype
-}  // namespace infer
+}  // namespace common
+}  // namespace colossalAI
diff --git a/extensions/csrc/common/target.h b/extensions/csrc/common/target.h
new file mode 100644
index 000000000..1c8a508e3
--- /dev/null
+++ b/extensions/csrc/common/target.h
@@ -0,0 +1,134 @@
+#pragma once
+
+#include <iostream>
+#include <stdexcept>
+#include <string>
+
+namespace colossalAI {
+namespace common {
+
+class Target {
+ public:
+  enum class OS : int {
+    Unk = -1,
+    Linux,
+    Windows,
+  };
+  enum class Arch : int {
+    Unk = -1,
+    X86,
+    Arm,
+    NVGPU,
+    AMDGPU,
+    Ascend,
+  };
+  enum class BitLen : int {
+    Unk = -1,
+    k32,
+    k64,
+  };
+
+  explicit Target(OS os, Arch arch, BitLen bitlen)
+      : os_(os), arch_(arch), bitlen_(bitlen) {}
+
+  bool defined() const {
+    return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk);
+  }
+
+  std::string str() const {
+    std::string s{"OS: "};
+    switch (os_) {
+      case OS::Unk:
+        s += "Unk";
+        break;
+      case OS::Linux:
+        s += "Linux";
+        break;
+      case OS::Windows:
+        s += "Windows";
+        break;
+      default:
+        throw std::invalid_argument("Invalid OS type!");
+    }
+    s += "\t";
+    s += "Arch: ";
+
+    switch (arch_) {
+      case Arch::Unk:
+        s += "Unk";
+        break;
+      case Arch::X86:
+        s += "X86";
+        break;
+      case Arch::Arm:
+        s += "Arm";
+        break;
+      case Arch::NVGPU:
+        s += "NVGPU";
+        break;
+      case Arch::AMDGPU:
+        s += "AMDGPU";
+        break;
+      case Arch::Ascend:
+        s += "Ascend";
+        break;
+      default:
+        throw std::invalid_argument("Invalid Arch type!");
+    }
+    s += "\t";
+    s += "BitLen: ";
+
+    switch (bitlen_) {
+      case BitLen::Unk:
+        s += "Unk";
+        break;
+      case BitLen::k32:
+        s += "k32";
+        break;
+      case BitLen::k64:
+        s += "k64";
+        break;
+      default:
+        throw std::invalid_argument("Invalid target bit length!");
+    }
+
+    return s;
+  }
+
+  OS os() const { return os_; }
+  Arch arch() const { return arch_; }
+  BitLen bitlen() const { return bitlen_; }
+
+  static Target DefaultX86Target();
+  static Target DefaultArmTarget();
+  static Target DefaultRocmTarget();
+  static Target DefaultAscendTarget();
+
+  static Target DefaultCUDATarget() {
+    return Target(OS::Linux, Arch::NVGPU, BitLen::k64);
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const Target& target);
+  friend bool operator==(const Target& lhs, const Target& rhs);
+  friend bool operator!=(const Target& lhs, const Target& rhs);
+
+ private:
+  OS os_{OS::Unk};
+  Arch arch_{Arch::Unk};
+  BitLen bitlen_{BitLen::Unk};
+};
+
+inline std::ostream& operator<<(std::ostream& os, const Target& target) {
+  return os << target.str();
+}
+inline bool operator==(const Target& lhs, const Target& rhs) {
+  return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) &&
+         (lhs.bitlen_ == rhs.bitlen_);
+}
+inline bool operator!=(const Target& lhs, const Target& rhs) {
+  return (lhs.os_ != rhs.os_) || (lhs.arch_ != rhs.arch_) ||
+         (lhs.bitlen_ != rhs.bitlen_);
+}
+
+}  // namespace common
+}  // namespace colossalAI
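`Target` is a plain value type, so a smoke test only needs the header above; a minimal sketch, assuming the include path is rooted at `extensions/csrc` (the out-of-line factories such as `DefaultX86Target` are declared but defined elsewhere):

```cpp
#include <iostream>

#include "common/target.h"

int main() {
  using colossalAI::common::Target;
  Target target = Target::DefaultCUDATarget();
  if (target.defined() && target.arch() == Target::Arch::NVGPU) {
    // Prints "OS: Linux\tArch: NVGPU\tBitLen: k64" via operator<<.
    std::cout << target << std::endl;
  }
  return 0;
}
```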
diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu
index 4121b67fc..5213a2313 100644
--- a/extensions/csrc/cuda/activation_kernel.cu
+++ b/extensions/csrc/cuda/activation_kernel.cu
@@ -2,13 +2,13 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-#include "type_shim.h"
-#include "include/mp_type_traits.h"
+#include "../common/micros.h"
+#include "../common/mp_type_traits.h"
 
 template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
-  using MT = typename infer::dtype::MPTypeTrait<T>::Type;
+  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
   return static_cast<T>((static_cast<MT>(x)) /
                         (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
 }
@@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel(
     const scalar_t* __restrict__ ins_data, scalar_t* __restrict__ outs_data,
     const int64_t numel) {
-  using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
+  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
   int64_t idx = static_cast<int64_t>(threadIdx.x) +
                 static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
   const int64_t grid_size = blockDim.x * gridDim.x;
diff --git a/extensions/csrc/cuda/compat.h b/extensions/csrc/cuda/compat.h
index a62beef91..e69de29bb 100644
--- a/extensions/csrc/cuda/compat.h
+++ b/extensions/csrc/cuda/compat.h
@@ -1,10 +0,0 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
-#ifndef TORCH_CHECK
-#define TORCH_CHECK AT_CHECK
-#endif
-
-#ifdef VERSION_GE_1_3
-#define DATA_PTR data_ptr
-#else
-#define DATA_PTR data
-#endif
diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
index 86db90c8b..15e613e35 100644
--- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-#include "type_shim.h"
+#include "../common/micros.h"
 
 template <typename scalar_t>
 __global__ void decode_kv_cache_memcpy_kernel(
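Two conventions in `activation_kernel.cu` are worth calling out: `MPTypeTrait` promotes reduced-precision types to `float` for the math, and the grid-stride loop decouples grid size from `numel`. A condensed sketch of the same pattern, with illustrative names rather than the PR's exact code, assuming it sits next to `activation_kernel.cu`:

```cpp
#include <cuda_runtime.h>

#include "../common/mp_type_traits.h"

// Grid-stride SiLU: each thread starts at its global index and advances by the
// total thread count, so any launch configuration covers all numel elements.
template <typename scalar_t>
__global__ void silu_grid_stride_kernel(const scalar_t* __restrict__ in,
                                        scalar_t* __restrict__ out,
                                        const int64_t numel) {
  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; idx < numel; idx += stride) {
    const MT x = static_cast<MT>(in[idx]);  // compute in float, store narrow
    out[idx] = static_cast<scalar_t>(x / (static_cast<MT>(1.0f) + expf(-x)));
  }
}
```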
diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h
index 38103c173..86409136b 100644
--- a/extensions/csrc/cuda/include/block_reduce.h
+++ b/extensions/csrc/cuda/include/block_reduce.h
@@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce(float *pval) {
   }
   warpReduce<ReduceType, num>(pval);
 }
+
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes(
+    T *x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = x[tid] + x[tid + i];
+    __syncthreads();
+  }
+
+  T final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = x[tid] + x[tid + 32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
+
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes_max_op(
+    T *x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
+    __syncthreads();
+  }
+
+  T final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
+    else
+      final = val;
+    // __SYNCWARP();
+
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final =
+          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+  }
+
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
diff --git a/extensions/csrc/cuda/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp
index 15a07bb0c..3439e5e71 100644
--- a/extensions/csrc/cuda/layer_norm_cuda.cpp
+++ b/extensions/csrc/cuda/layer_norm_cuda.cpp
@@ -7,7 +7,7 @@
 #include <cassert>
 #include <vector>
 
-#include "compat.h"
+#include "../common/micros.h"
 
 namespace {
diff --git a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu
index 72b84d6ca..17d5b10f4 100644
--- a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu
+++ b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu
@@ -9,7 +9,7 @@
 #include "ATen/AccumulateType.h"
 #include "ATen/cuda/CUDAContext.h"
 #include "ATen/cuda/DeviceUtils.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 
 template <typename U>
 __device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
diff --git a/extensions/csrc/cuda/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu
index 9cc3ae1ea..b7793b364 100644
--- a/extensions/csrc/cuda/multi_tensor_adam.cu
+++ b/extensions/csrc/cuda/multi_tensor_adam.cu
@@ -15,7 +15,7 @@
 #include <assert.h>
 
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
diff --git a/extensions/csrc/cuda/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh
index ec55dd320..01a858661 100644
--- a/extensions/csrc/cuda/multi_tensor_apply.cuh
+++ b/extensions/csrc/cuda/multi_tensor_apply.cuh
@@ -12,7 +12,7 @@
 #include <assert.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#include "compat.h"
+#include "../common/micros.h"
 
 // #include <iostream>
diff --git a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
index 85f935152..57a79f7a8 100644
--- a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
+++ b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
@@ -11,7 +11,8 @@
 #include <assert.h>
 
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
+#include "include/block_reduce.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
diff --git a/extensions/csrc/cuda/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu
index 63771cf40..50dfc56bc 100644
--- a/extensions/csrc/cuda/multi_tensor_lamb.cu
+++ b/extensions/csrc/cuda/multi_tensor_lamb.cu
@@ -10,7 +10,7 @@
 #include <assert.h>
 
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
diff --git a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu
index 2f58a0f16..0dec1d5d1 100644
--- a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu
+++ b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu
@@ -10,7 +10,7 @@
 #include <assert.h>
 
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+#include "../common/micros.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
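`reduce_block_into_lanes` lands in `block_reduce.h` unchanged; with `lanes = 1` it collapses a whole block into a single value held by thread 0. A minimal caller sketch, assuming `blockDim.x` is a multiple of 32 as the helper's own comments require (kernel and buffer names are illustrative):

```cpp
#include <cuda_runtime.h>

#include "include/block_reduce.h"

// One partial sum per block: smem must hold one element per thread.
template <int kBlockSize>
__global__ void block_sum_kernel(const float* __restrict__ in,
                                 float* __restrict__ out, int n) {
  __shared__ float smem[kBlockSize];
  const int gid = blockIdx.x * kBlockSize + threadIdx.x;
  const float val = (gid < n) ? in[gid] : 0.f;
  // lanes = 1, share_result = false: only thread 0 needs the reduced value.
  const float sum = reduce_block_into_lanes(smem, val, 1, false);
  if (threadIdx.x == 0) out[blockIdx.x] = sum;
}

// Launch example: block_sum_kernel<256><<<num_blocks, 256>>>(in, out, n);
```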
diff --git a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu
index 7f48dbd5d..d0cf786f8 100644
--- a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu
+++ b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu
@@ -7,7 +7,7 @@
 #include <assert.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#include "compat.h"
+#include "../common/micros.h"
 #include "multi_tensor_apply.cuh"
 
 #define BLOCK_SIZE 512
diff --git a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu
index 41781ebc7..2f968d30f 100644
--- a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu
+++ b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu
@@ -10,7 +10,7 @@
 #include <cuda_profiler_api.h>
 
 #include "scaled_masked_softmax.h"
-#include "type_shim.h"
+#include "../common/micros.h"
 
 namespace multihead_attn {
 namespace fused_softmax {
diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu
index 62c56e6f7..d9550dc2c 100644
--- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu
+++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu
@@ -10,7 +10,7 @@
 #include <cuda_profiler_api.h>
 
 #include "scaled_upper_triang_masked_softmax.h"
-#include "type_shim.h"
+#include "../common/micros.h"
 
 namespace multihead_attn {
 namespace fused_softmax {
diff --git a/extensions/csrc/cuda/utils/gpu_launch_config.h b/extensions/csrc/cuda/utils/gpu_launch_config.h
new file mode 100644
index 000000000..c7481323a
--- /dev/null
+++ b/extensions/csrc/cuda/utils/gpu_launch_config.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+namespace colossalAI {
+namespace cuda {
+namespace utils {
+
+class GPULaunchConfig {
+ public:
+  GPULaunchConfig() = default;
+  GPULaunchConfig(const dim3& block, const dim3& grid)
+      : block_(block), grid_(grid) {}
+  friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
+
+ protected:
+  void set_block(const dim3& dim) { block_ = dim; }
+  void set_grid(const dim3& dim) { grid_ = dim; }
+
+ private:
+  dim3 block_{1, 1, 1};
+  dim3 grid_{1, 1, 1};
+};
+
+GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
+
+// TODO(LiuYang): to be implemented
+GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size);
+
+// TODO(LiuYang): to be implemented
+GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size);
+
+}  // namespace utils
+}  // namespace cuda
+}  // namespace colossalAI
diff --git a/extensions/csrc/cuda/utils/micros.h b/extensions/csrc/cuda/utils/micros.h
new file mode 100644
index 000000000..9b410e3d8
--- /dev/null
+++ b/extensions/csrc/cuda/utils/micros.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <cuda_runtime.h>
+
+#include <stdexcept>
+#include <string>
+
+#define CUDA_CHECK(func)                                      \
+  {                                                           \
+    auto status = func;                                       \
+    if (status != cudaSuccess) {                              \
+      throw std::runtime_error(std::string("CUDA Error: ") +  \
+                               cudaGetErrorString(status));   \
+    }                                                         \
+  }
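`CUDA_CHECK` wraps any call returning `cudaError_t` and surfaces failures as exceptions carrying `cudaGetErrorString`'s message. A minimal usage sketch (the surrounding function is hypothetical, with the include path rooted at `extensions/csrc/cuda`):

```cpp
#include <cuda_runtime.h>

#include "utils/micros.h"

// Returns the number of visible CUDA devices, throwing on runtime failure.
int visible_device_count() {
  int count = 0;
  CUDA_CHECK(cudaGetDeviceCount(&count));
  return count;
}
```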
diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.cc b/extensions/csrc/cuda/utils/nvgpu_dev_info.cc
new file mode 100644
index 000000000..e52abebff
--- /dev/null
+++ b/extensions/csrc/cuda/utils/nvgpu_dev_info.cc
@@ -0,0 +1,46 @@
+#include "nvgpu_dev_info.h"
+
+#include <array>
+
+namespace colossalAI {
+namespace cuda {
+namespace utils {
+
+std::array<int, 3> NVGPUDevInfo::GetMaxGridDims() const {
+  std::array<int, 3> ret;
+  ret[0] = prop_.maxGridSize[0];
+  ret[1] = prop_.maxGridSize[1];
+  ret[2] = prop_.maxGridSize[2];
+  return ret;
+}
+
+std::array<int, 3> NVGPUDevInfo::GetMaxBlockDims() const {
+  std::array<int, 3> ret;
+  ret[0] = prop_.maxThreadsDim[0];
+  ret[1] = prop_.maxThreadsDim[1];
+  ret[2] = prop_.maxThreadsDim[2];
+  return ret;
+}
+
+std::array<int, 2> NVGPUDevInfo::GetCapability() const {
+  std::array<int, 2> ret;
+  ret[0] = prop_.major;
+  ret[1] = prop_.minor;
+  return ret;
+}
+
+int NVGPUDevInfo::GetMultiProcessorCount() const {
+  return prop_.multiProcessorCount;
+}
+
+int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const {
+  return prop_.maxThreadsPerMultiProcessor;
+}
+
+int NVGPUDevInfo::GetMaxThreadsPerBlock() const {
+  return prop_.maxThreadsPerBlock;
+}
+
+}  // namespace utils
+}  // namespace cuda
+}  // namespace colossalAI
diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.h b/extensions/csrc/cuda/utils/nvgpu_dev_info.h
new file mode 100644
index 000000000..c8c67c908
--- /dev/null
+++ b/extensions/csrc/cuda/utils/nvgpu_dev_info.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <array>
+#include <string>
+#include <vector>
+
+#include "micros.h"
+#include "target.h"
+
+namespace colossalAI {
+namespace cuda {
+namespace utils {
+
+class NVGPUDevInfo {
+ public:
+  explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
+    CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num));
+  }
+
+  std::array<int, 3> GetMaxGridDims() const;
+  std::array<int, 3> GetMaxBlockDims() const;
+  std::array<int, 2> GetCapability() const;
+  int GetMultiProcessorCount() const;
+  int GetMaxThreadsPerMultiProcessor() const;
+  int GetMaxThreadsPerBlock() const;
+
+ private:
+  int device_num_;
+  cudaDeviceProp prop_;
+};
+
+}  // namespace utils
+}  // namespace cuda
+}  // namespace colossalAI
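Putting the pieces together, `DevInfoMgr` from `dev_info_mgr.h` is just a typed factory over info classes like `NVGPUDevInfo`. A minimal caller sketch, assuming include paths rooted at `extensions/csrc` and that the template parameter names the concrete info class, as `GetDevInfo`'s `make_unique` call implies:

```cpp
#include <iostream>

#include "common/dev_info_mgr.h"
#include "cuda/utils/nvgpu_dev_info.h"

int main() {
  namespace cc = colossalAI::common;
  namespace cu = colossalAI::cuda::utils;
  // Query device 0 through the factory; throws if the CUDA query fails.
  auto info = cc::DevInfoMgr<cu::NVGPUDevInfo>::GetDevInfo(0);
  const auto cap = info->GetCapability();
  std::cout << "SM " << cap[0] << "." << cap[1]
            << ", max threads/block: " << info->GetMaxThreadsPerBlock()
            << std::endl;
  return 0;
}
```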