Mirror of https://github.com/hpcaitech/ColossalAI.git
[NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu code style (#958)
commit 89e2767a92
parent 1dc1b6fa00
@@ -16,7 +16,8 @@
 #define BLOCK_SIZE 512
 #define ILP 4
 
-template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T>
+__device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
 
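For context on what the hunk above touches: is_aligned guards the vectorized fast path of these kernels, and the load_store helper named in the next hunk's header performs the wide copy once alignment holds. Below is a minimal, standalone sketch of that pattern; the aligned_storage-based LT typedef and the copy_demo kernel are illustrative assumptions, not code from this file.

#include <cstdint>
#include <type_traits>

#define ILP 4

// Same test as above: the pointer must sit on an ILP*sizeof(T)-byte boundary
// so that ILP consecutive elements can move in one wide memory transaction.
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

// Sketch of the wide copy; LT is assumed to be an ILP*sizeof(T)-byte POD type.
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
                                           int src_offset) {
  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}

// Hypothetical kernel showing how the two pieces combine.
__global__ void copy_demo(float *dst, float *src, int n) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (is_aligned(dst) && is_aligned(src) && n % ILP == 0) {
    if (tid < n / ILP) load_store(dst, src, tid, tid);  // one 16-byte move
  } else {
    if (tid < n) dst[tid] = src[tid];  // scalar fallback
  }
}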
@@ -28,9 +29,10 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
   ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
 }
 
-template <typename x_t> struct L2NormFunctor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
-             float *output, float *output_per_tensor, bool per_tensor,
-             int max_chunks_per_tensor) {
+template <typename x_t>
+struct L2NormFunctor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+      float *output, float *output_per_tensor, bool per_tensor,
+      int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
@@ -48,8 +50,8 @@ template <typename x_t> struct L2NormFunctor {
 
     __shared__ float s_vals[512];
 
-    float
-        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    float vals[ILP];  // = {0}; // this probably works too but I want to be
+                      // sure...
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
@@ -84,8 +86,7 @@ template <typename x_t> struct L2NormFunctor {
     }
 
     float val = 0.f;
-    for (int i = 0; i < ILP; i++)
-      val += vals[i];
+    for (int i = 0; i < ILP; i++) val += vals[i];
 
     float final = reduce_block_into_lanes(s_vals, val);
 
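The hunk above is where each thread folds its ILP partial sums of squares into val before the block-wide reduction. reduce_block_into_lanes is defined elsewhere (not in this diff); the sketch below is an illustrative stand-in of my own (block_sum_sketch), assuming a power-of-two block size such as the BLOCK_SIZE 512 defined earlier, showing the kind of shared-memory tree reduction such a helper performs.

// Illustrative stand-in for the block-wide sum; not the real helper.
__device__ float block_sum_sketch(float *smem, float val) {
  int tid = threadIdx.x;
  smem[tid] = val;  // every thread publishes its partial sum of squares
  __syncthreads();
  // Pairwise tree reduction; requires blockDim.x to be a power of two.
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) smem[tid] += smem[tid + stride];
    __syncthreads();
  }
  float total = smem[0];  // block-wide total for this chunk
  __syncthreads();        // keep smem safely reusable by the caller
  return total;
}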
@@ -104,9 +105,10 @@ template <typename x_t> struct L2NormFunctor {
 
 // Probably better to template, but since we are not likely to support other
 // norm
-template <typename x_t> struct MaxNormFunctor {
-  __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
-             float *output, float *output_per_tensor, bool per_tensor,
-             int max_chunks_per_tensor) {
+template <typename x_t>
+struct MaxNormFunctor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+      float *output, float *output_per_tensor, bool per_tensor,
+      int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
@@ -124,8 +126,8 @@ template <typename x_t> struct MaxNormFunctor {
 
     __shared__ float s_vals[512];
 
-    float
-        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    float vals[ILP];  // = {0}; // this probably works too but I want to be
+                      // sure...
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
@@ -160,8 +162,7 @@ template <typename x_t> struct MaxNormFunctor {
     }
 
     float val = 0.f;
-    for (int i = 0; i < ILP; i++)
-      val = fmaxf(fabsf(val), fabsf(vals[i]));
+    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
 
     float final = reduce_block_into_lanes_max_op(s_vals, val);
 
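Same reformatting as the L2 path, but the accumulation here is an infinity norm: each thread keeps the running maximum of |x_i| over its elements, so there is nothing to square and no sqrt in the cleanup stage. A tiny hypothetical device helper showing the same idea:

// Hypothetical helper: running max of absolute values over one thread's slice.
__device__ float max_abs_sketch(const float *x, int n) {
  float m = 0.f;
  for (int i = 0; i < n; i++) m = fmaxf(m, fabsf(x[i]));
  return m;  // equals the slice's infinity norm
}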
@@ -185,13 +186,11 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
 
   if (blockIdx.x == 0) {
     float val = 0;
-    if (threadIdx.x < 320)
-      val = output[threadIdx.x];
+    if (threadIdx.x < 320) val = output[threadIdx.x];
 
     float final = reduce_block_into_lanes(vals, val);
 
-    if (threadIdx.x == 0)
-      *ret = sqrt(final);
+    if (threadIdx.x == 0) *ret = sqrt(final);
   }
 
   if (per_tensor) {
@@ -204,8 +203,7 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
 
     float final = reduce_block_into_lanes(vals, val);
 
-    if (threadIdx.x == 0)
-      ret_per_tensor[blockIdx.x] = sqrt(final);
+    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
   }
 }
 
@@ -217,17 +215,14 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
 
   if (blockIdx.x == 0) {
     float val = 0;
-    if (threadIdx.x < 320)
-      val = output[threadIdx.x];
+    if (threadIdx.x < 320) val = output[threadIdx.x];
 
     if (norm_type == 0) {
       float final = reduce_block_into_lanes_max_op(vals, val);
-      if (threadIdx.x == 0)
-        *ret = alpha * (*ret) + beta * final;
+      if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
     } else {
       float final = reduce_block_into_lanes(vals, val);
-      if (threadIdx.x == 0)
-        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
+      if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
     }
   }
 
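The hunk above only joins split if-statements, but the arithmetic it carries is worth spelling out: the L2 branch folds a freshly reduced sum of squares (final) into a running norm as sqrt(alpha * ret^2 + beta * final), while the norm_type == 0 branch blends maxima linearly as alpha * ret + beta * final. A host-side sanity sketch of the L2 identity, illustrative and not part of this file, with made-up chunk values:

#include <cassert>
#include <cmath>

int main() {
  // Chunk A = {1, 2}: L2 norm sqrt(5). Chunk B = {2, 4, 4}: sum of squares 36.
  float ret = std::sqrt(5.f);  // running norm after chunk A
  float final_sum = 36.f;      // reduced sum of squares for chunk B
  // With alpha = beta = 1 the update reproduces the whole-vector norm sqrt(41).
  float combined = std::sqrt(1.f * ret * ret + 1.f * final_sum);
  assert(std::fabs(combined - std::sqrt(41.f)) < 1e-6f);
  return 0;
}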
@@ -260,8 +255,8 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
   }
 }
 
-std::tuple<at::Tensor, at::Tensor>
-multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
-                         std::vector<std::vector<at::Tensor>> tensor_lists,
-                         at::optional<bool> per_tensor_python) {
+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists,
+    at::optional<bool> per_tensor_python) {
   bool per_tensor =
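For reference, a hypothetical call site for the entry point reformatted above. The tensor shapes, the chunk size of 2048 * 32, and the variable names are my assumptions, not values prescribed by this file.

#include <ATen/ATen.h>
#include <tuple>
#include <vector>

// Declared in this file; assumed to return {global L2 norm, per-tensor norms}.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

void example() {
  at::Tensor p0 = at::randn({1024}, at::kCUDA);
  at::Tensor p1 = at::randn({2048}, at::kCUDA);
  // One int32 flag tensor used by the multi-tensor machinery (typically zeros).
  at::Tensor noop =
      at::zeros({1}, at::TensorOptions().device(at::kCUDA).dtype(at::kInt));
  auto result = multi_tensor_l2norm_cuda(/*chunk_size=*/2048 * 32, noop,
                                         {{p0, p1}}, /*per_tensor_python=*/true);
  at::Tensor total_norm = std::get<0>(result);        // scalar L2 norm of all inputs
  at::Tensor per_tensor_norms = std::get<1>(result);  // one norm per input tensor
}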