[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)

* refactor compilation mechanism and unified multi hw

* fix file path bug

* add init.py to make pybind a module to avoid relative path error caused by softlink

* delete duplicated micros

* fix micros bug in gcc
Author: 傅剑寒
Date: 2024-04-24 14:17:54 +08:00 (committed by GitHub)
Parent: 04863a9b14 · Commit: 279300dc5f
64 changed files with 345 additions and 310 deletions


@@ -0,0 +1,78 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <stdexcept>

#include "nvgpu_dev_info.h"

namespace colossalAI {
namespace cuda {
namespace utils {

struct GPULaunchConfig {
  dim3 block{1, 1, 1};
  dim3 grid{1, 1, 1};
};
static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info,
                                            int64_t numel, int64_t vec_size) {
  const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock();
  const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0];
  const int64_t kMinimumSize = 64;
  const int64_t kMaximumSize = 512;
  int64_t active_threads = (numel + vec_size - 1) / vec_size;
  int64_t sm_num = dev_info.GetMultiProcessorCount();

  // Note(LiuYang): the expected thread count should generally be one of
  // {64, 128, 256, 512}.
  int64_t expected_threads_per_block = kMaximumSize;

  // Round x up to the next power of two (x is returned unchanged if it is
  // already a power of two).
  auto RoundUpToPowerOfTwo = [](int64_t x) {
    bool is_power_of_two = false;
    int64_t ret = 1;
    int64_t y = x;
    while (y > 0) {
      is_power_of_two = ((ret ^ x) == 0);
      y = (y >> 1);  // shift y, not x, so the loop terminates
      ret = (ret << 1);
      if (y > 0) is_power_of_two = false;
    }
    if (is_power_of_two) return x;
    return ret;
  };
  if ((active_threads / (sm_num << 1)) < max_threads_per_block) {
    expected_threads_per_block =
        RoundUpToPowerOfTwo(active_threads / (sm_num << 1));
  } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) {
    expected_threads_per_block =
        RoundUpToPowerOfTwo(active_threads / (sm_num << 2));
  }

  expected_threads_per_block =
      std::max(expected_threads_per_block, kMinimumSize);
  int64_t expect_block_per_grid =
      (active_threads + expected_threads_per_block - 1) /
      expected_threads_per_block;

  if (expect_block_per_grid > max_blocks_per_grid) {
    expect_block_per_grid = max_blocks_per_grid;
    expected_threads_per_block =
        (active_threads + expect_block_per_grid - 1) / expect_block_per_grid;
    if (expected_threads_per_block > max_threads_per_block)
      throw std::invalid_argument(
          "Threads required for the current input exceed the limit of the "
          "current GPU!");
    expected_threads_per_block =
        RoundUpToPowerOfTwo(expected_threads_per_block);
    expect_block_per_grid = (active_threads + expected_threads_per_block - 1) /
                            expected_threads_per_block;
  }

  GPULaunchConfig config;
  config.block.x = expected_threads_per_block;
  config.grid.x = expect_block_per_grid;
  return config;
}
} // namespace utils
} // namespace cuda
} // namespace colossalAI
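
As a hedged illustration of how these helpers fit together (not part of this commit), the following sketch launches a hypothetical elementwise kernel with a configuration produced by GetGPULaunchConfig1D for device 0:

// Illustrative sketch (not in this commit): scale `numel` floats in place
// using a launch configuration derived from the device properties.
__global__ void scale_kernel(float *data, float alpha, int64_t numel) {
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (idx < numel) data[idx] *= alpha;
}

void launch_scale(float *data, float alpha, int64_t numel, cudaStream_t stream) {
  using namespace colossalAI::cuda::utils;
  NVGPUDevInfo dev_info(/*device_num=*/0);
  GPULaunchConfig cfg = GetGPULaunchConfig1D(dev_info, numel, /*vec_size=*/1);
  scale_kernel<<<cfg.grid, cfg.block, 0, stream>>>(data, alpha, numel);
}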


@@ -0,0 +1,18 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdexcept>

#define CUDA_CHECK(func)                                      \
  {                                                           \
    auto status = func;                                       \
    if (status != cudaSuccess) {                              \
      throw std::runtime_error(cudaGetErrorString(status));   \
    }                                                         \
  }

#define HOST __host__
#define DEVICE __device__
#define HOSTDEVICE __host__ __device__
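
A brief usage sketch (illustrative, not part of this commit): CUDA_CHECK turns any failing CUDA runtime call into a std::runtime_error carrying the corresponding error string.

// Illustrative sketch: allocate and zero a device buffer, converting CUDA
// errors into C++ exceptions via CUDA_CHECK.
void *alloc_zeroed_device_buffer(size_t bytes) {
  void *ptr = nullptr;
  CUDA_CHECK(cudaMalloc(&ptr, bytes));
  CUDA_CHECK(cudaMemset(ptr, 0, bytes));
  return ptr;
}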


@@ -0,0 +1,60 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

#include <array>
#include <ostream>
#include <string>
#include <vector>

#include "micros.h"

namespace colossalAI {
namespace cuda {
namespace utils {
class NVGPUDevInfo {
 public:
  explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
    CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num));
  }

  std::array<int, 3> GetMaxGridDims() const {
    std::array<int, 3> ret;
    ret[0] = prop_.maxGridSize[0];
    ret[1] = prop_.maxGridSize[1];
    ret[2] = prop_.maxGridSize[2];
    return ret;
  }

  std::array<int, 3> GetMaxBlockDims() const {
    std::array<int, 3> ret;
    ret[0] = prop_.maxThreadsDim[0];
    ret[1] = prop_.maxThreadsDim[1];
    ret[2] = prop_.maxThreadsDim[2];
    return ret;
  }

  std::array<int, 2> GetCapability() const {
    std::array<int, 2> ret;
    ret[0] = prop_.major;
    ret[1] = prop_.minor;
    return ret;
  }

  int GetMultiProcessorCount() const { return prop_.multiProcessorCount; }

  int GetMaxThreadsPerMultiProcessor() const {
    return prop_.maxThreadsPerMultiProcessor;
  }

  int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; }

 private:
  int device_num_;
  cudaDeviceProp prop_;
};
} // namespace utils
} // namespace cuda
} // namespace colossalAI
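
For illustration (not part of this commit), the wrapper can be queried like this to print a short summary of device 0:

#include <cstdio>

// Illustrative sketch: query a few properties of device 0 through NVGPUDevInfo.
void print_device_summary() {
  colossalAI::cuda::utils::NVGPUDevInfo info(/*device_num=*/0);
  const auto cap = info.GetCapability();
  const auto grid = info.GetMaxGridDims();
  std::printf("SMs: %d, compute capability: %d.%d, max grid.x: %d, "
              "max threads/block: %d\n",
              info.GetMultiProcessorCount(), cap[0], cap[1], grid[0],
              info.GetMaxThreadsPerBlock());
}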


@@ -0,0 +1,60 @@
#pragma once

#include <cuda_fp16.h>
#include <stdint.h>

#include <algorithm>

#include <torch/extension.h>

#include "common/vec_type_traits.h"
#include "funcs/cast_functor.h"

namespace colossalAI {
namespace cuda {
namespace utils {
template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
  // Note(LiuYang): static_cast can't be used here to cast between two
  // unrelated pointer types, so reinterpret_cast is required.
  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}

template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
  // Since the maximum memory alignment length is 128 bits, we use two float4
  // loads/stores here.
  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
  *(reinterpret_cast<float4 *>(dst + 4)) =
      *(reinterpret_cast<const float4 *>(src + 4));
}

template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) {
  using VT = typename common::VecTypeTrait<T, VecSize>::Type;
  *(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
}
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
  uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
  const int max_aligned_size = 128;       // bits
  const int dtype_size = sizeof(T) * 8;   // bits
  const int vec_size = max_aligned_size / sizeof(T) / 8;

  // Note(LiuYang): performance of the vec_size == 8 case still needs to be
  // profiled in the future.
  // if (address % (dtype_size * 8) == 0) {
  //   return std::min(8, vec_size);
  // }

  // Note: dtype_size is expressed in bits while `address` is in bytes, so
  // these checks demand stricter alignment than the vector width itself
  // requires; the result is conservative but always safe.
  if (address % (dtype_size * 4) == 0) {
    return std::min(4, vec_size);
  } else if (address % (dtype_size * 2) == 0) {
    return std::min(2, vec_size);
  } else {
    return 1;
  }
}
} // namespace utils
} // namespace cuda
} // namespace colossalAI
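
To make the intended use of these helpers concrete, here is a minimal sketch of a vectorized device-to-device copy driven by get_vec_size. The kernel and host wrapper are illustrative only (not part of this commit); the sketch assumes the headers above are available in the translation unit, that common::VecTypeTrait is specialized for widths 2 and 4, and that numel is divisible by the chosen vector width.

// Illustrative sketch (not in this commit): pick the widest safe vector
// width with get_vec_size and copy `numel` floats with copy_vector.
template <int VecSize>
__global__ void vec_copy_kernel(const float *src, float *dst, int64_t numel) {
  int64_t idx =
      (blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x) * VecSize;
  if (idx + VecSize <= numel) {
    colossalAI::cuda::utils::copy_vector<float, VecSize>(dst + idx, src + idx);
  }
}

void copy_float_tensor(const torch::Tensor &src, torch::Tensor &dst,
                       cudaStream_t stream) {
  const int64_t numel = src.numel();
  const int vec_size = colossalAI::cuda::utils::get_vec_size<float>(src);
  const int64_t threads = 256;
  const int64_t blocks = (numel / vec_size + threads - 1) / threads;
  switch (vec_size) {
    case 4:
      vec_copy_kernel<4><<<blocks, threads, 0, stream>>>(
          src.data_ptr<float>(), dst.data_ptr<float>(), numel);
      break;
    case 2:
      vec_copy_kernel<2><<<blocks, threads, 0, stream>>>(
          src.data_ptr<float>(), dst.data_ptr<float>(), numel);
      break;
    default:
      // Fall back to a plain async copy when no vectorized width applies.
      CUDA_CHECK(cudaMemcpyAsync(dst.data_ptr<float>(), src.data_ptr<float>(),
                                 numel * sizeof(float),
                                 cudaMemcpyDeviceToDevice, stream));
      break;
  }
}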