From 6fcb381801761aded63d81fc6bb180a46f00d6e7 Mon Sep 17 00:00:00 2001
From: Wangbo Zhao <56866854+wangbo-zhao@users.noreply.github.com>
Date: Sat, 2 Apr 2022 10:45:04 +0800
Subject: [PATCH] [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu code style (#635)

---
 .../csrc/multi_tensor_l2norm_kernel.cu        | 676 ++++++++----------
 1 file changed, 304 insertions(+), 372 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu
index 03f60b34c..8686e83f8 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu
@@ -1,4 +1,5 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
+// modified from
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -9,356 +10,311 @@
 #include <assert.h>
 
-#include "type_shim.h"
 #include "multi_tensor_apply.cuh"
+#include "type_shim.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
 
-template <typename T>
-__device__ __forceinline__ bool is_aligned(T *p)
-{
-    return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
+template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
+  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
 
 template <typename T>
-__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset)
-{
-    typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
-    ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
+__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
+                                           int src_offset) {
+  typedef
+      typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
+  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
 }
 
-template <typename x_t>
-struct L2NormFunctor
-{
-    __device__ __forceinline__ void operator()(
-        int chunk_size,
-        volatile int *noop_gmem,
-        TensorListMetadata<1> &tl,
-        float *output,
-        float *output_per_tensor,
-        bool per_tensor,
-        int max_chunks_per_tensor)
-    {
-        // I'd like this kernel to propagate infs/nans.
-        // if(*noop_gmem == 1)
-        //   return;
+template <typename x_t> struct L2NormFunctor {
+  __device__ __forceinline__ void
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+             float *output, float *output_per_tensor, bool per_tensor,
+             int max_chunks_per_tensor) {
+    // I'd like this kernel to propagate infs/nans.
+    // if(*noop_gmem == 1)
+    //   return;
 
-        int tensor_loc = tl.block_to_tensor[blockIdx.x];
-        int chunk_idx = tl.block_to_chunk[blockIdx.x];
-        int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
 
-        x_t *x = (x_t *)tl.addresses[0][tensor_loc];
-        x += chunk_idx * chunk_size;
+    x_t *x = (x_t *)tl.addresses[0][tensor_loc];
+    x += chunk_idx * chunk_size;
 
-        n -= chunk_idx * chunk_size;
+    n -= chunk_idx * chunk_size;
 
-        __shared__ float s_vals[512];
+    __shared__ float s_vals[512];
 
-        float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
-        x_t r_x[ILP];
-        for (int i = 0; i < ILP; i++)
-        {
-            vals[i] = 0.f;
-            r_x[i] = 0;
-        }
-
-        // to make things simple, we put aligned case in a different code path
-        if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
-        {
-            for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
-            {
-                // load
-                load_store(r_x, x, 0, i_start);
-#pragma unroll
-                for (int ii = 0; ii < ILP; ii++)
-                {
-                    float next = static_cast<float>(r_x[ii]);
-                    vals[ii] += next * next;
-                }
-            }
-        }
-        else
-        {
-            for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP)
-            {
-#pragma unroll
-                for (int ii = 0; ii < ILP; ii++)
-                {
-                    int i = i_start + threadIdx.x + ii * blockDim.x;
-                    if (i < n && i < chunk_size)
-                    {
-                        float next = static_cast<float>(x[i]);
-                        vals[ii] += next * next;
-                    }
-                }
-            }
-        }
-
-        float val = 0.f;
-        for (int i = 0; i < ILP; i++)
-            val += vals[i];
-
-        float final = reduce_block_into_lanes(s_vals, val);
-
-        if (threadIdx.x == 0)
-        {
-            if (!isfinite(final))
-                *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
-            output[blockIdx.x] += final;
-            if (per_tensor)
-                output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;
-        }
-    }
+    float
+        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    x_t r_x[ILP];
+    for (int i = 0; i < ILP; i++) {
+      vals[i] = 0.f;
+      r_x[i] = 0;
+    }
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
+      for (int i_start = threadIdx.x;
+           i_start * ILP < n && i_start * ILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_x, x, 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          float next = static_cast<float>(r_x[ii]);
+          vals[ii] += next * next;
+        }
+      }
+    } else {
+      for (int i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * ILP) {
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) {
+            float next = static_cast<float>(x[i]);
+            vals[ii] += next * next;
+          }
+        }
+      }
+    }
+
+    float val = 0.f;
+    for (int i = 0; i < ILP; i++)
+      val += vals[i];
+
+    float final = reduce_block_into_lanes(s_vals, val);
+
+    if (threadIdx.x == 0) {
+      if (!isfinite(final))
+        *noop_gmem =
+            1; // Blindly fire off a write. These will race but that's ok.
+      output[blockIdx.x] += final;
+      if (per_tensor)
+        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
+                              max_chunks_per_tensor +
+                          chunk_idx] = final;
+    }
+  }
 };
 
-// Probably better to template, but since we are not likely to support other norm
-template <typename x_t>
-struct MaxNormFunctor
-{
-    __device__ __forceinline__ void operator()(
-        int chunk_size,
-        volatile int *noop_gmem,
-        TensorListMetadata<1> &tl,
-        float *output,
-        float *output_per_tensor,
-        bool per_tensor,
-        int max_chunks_per_tensor)
-    {
-        // I'd like this kernel to propagate infs/nans.
-        // if(*noop_gmem == 1)
-        //   return;
+// Probably better to template, but since we are not likely to support other
+// norm
+template <typename x_t> struct MaxNormFunctor {
+  __device__ __forceinline__ void
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+             float *output, float *output_per_tensor, bool per_tensor,
+             int max_chunks_per_tensor) {
+    // I'd like this kernel to propagate infs/nans.
+    // if(*noop_gmem == 1)
+    //   return;
 
-        int tensor_loc = tl.block_to_tensor[blockIdx.x];
-        int chunk_idx = tl.block_to_chunk[blockIdx.x];
-        int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
 
-        x_t *x = (x_t *)tl.addresses[0][tensor_loc];
-        x += chunk_idx * chunk_size;
+    x_t *x = (x_t *)tl.addresses[0][tensor_loc];
+    x += chunk_idx * chunk_size;
 
-        n -= chunk_idx * chunk_size;
+    n -= chunk_idx * chunk_size;
 
-        __shared__ float s_vals[512];
+    __shared__ float s_vals[512];
 
-        float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
-        x_t r_x[ILP];
-        for (int i = 0; i < ILP; i++)
-        {
-            vals[i] = 0.f;
-            r_x[i] = 0;
-        }
-
-        // to make things simple, we put aligned case in a different code path
-        if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
-        {
-            for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
-            {
-                // load
-                load_store(r_x, x, 0, i_start);
-#pragma unroll
-                for (int ii = 0; ii < ILP; ii++)
-                {
-                    float next = static_cast<float>(r_x[ii]);
-                    vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
-                }
-            }
-        }
-        else
-        {
-            for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP)
-            {
-#pragma unroll
-                for (int ii = 0; ii < ILP; ii++)
-                {
-                    int i = i_start + threadIdx.x + ii * blockDim.x;
-                    if (i < n && i < chunk_size)
-                    {
-                        float next = static_cast<float>(x[i]);
-                        vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
-                    }
-                }
-            }
-        }
-
-        float val = 0.f;
-        for (int i = 0; i < ILP; i++)
-            val = fmaxf(fabsf(val), fabsf(vals[i]));
-
-        float final = reduce_block_into_lanes_max_op(s_vals, val);
-
-        if (threadIdx.x == 0)
-        {
-            if (!isfinite(final))
-                *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
-            output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
-            if (per_tensor)
-                output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;
-        }
-    }
+    float
+        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
+    x_t r_x[ILP];
+    for (int i = 0; i < ILP; i++) {
+      vals[i] = 0.f;
+      r_x[i] = 0;
+    }
+
+    // to make things simple, we put aligned case in a different code path
+    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
+      for (int i_start = threadIdx.x;
+           i_start * ILP < n && i_start * ILP < chunk_size;
+           i_start += blockDim.x) {
+        // load
+        load_store(r_x, x, 0, i_start);
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          float next = static_cast<float>(r_x[ii]);
+          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
+        }
+      }
+    } else {
+      for (int i_start = 0; i_start < n && i_start < chunk_size;
+           i_start += blockDim.x * ILP) {
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if (i < n && i < chunk_size) {
+            float next = static_cast<float>(x[i]);
+            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
+          }
+        }
+      }
+    }
+
+    float val = 0.f;
+    for (int i = 0; i < ILP; i++)
+      val = fmaxf(fabsf(val), fabsf(vals[i]));
+
+    float final = reduce_block_into_lanes_max_op(s_vals, val);
+
+    if (threadIdx.x == 0) {
+      if (!isfinite(final))
+        *noop_gmem =
+            1; // Blindly fire off a write. These will race but that's ok.
+      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
+      if (per_tensor)
+        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
+                              max_chunks_per_tensor +
+                          chunk_idx] = final;
+    }
+  }
 };
 
-__global__ void cleanup(
-    float *output,
-    float *output_per_tensor,
-    float *ret,
-    float *ret_per_tensor,
-    bool per_tensor,
-    int max_chunks_per_tensor)
-{
-    __shared__ float vals[512];
+__global__ void cleanup(float *output, float *output_per_tensor, float *ret,
+                        float *ret_per_tensor, bool per_tensor,
+                        int max_chunks_per_tensor) {
+  __shared__ float vals[512];
 
-    if (blockIdx.x == 0)
-    {
-        float val = 0;
-        if (threadIdx.x < 320)
-            val = output[threadIdx.x];
+  if (blockIdx.x == 0) {
+    float val = 0;
+    if (threadIdx.x < 320)
+      val = output[threadIdx.x];
 
-        float final = reduce_block_into_lanes(vals, val);
+    float final = reduce_block_into_lanes(vals, val);
 
-        if (threadIdx.x == 0)
-            *ret = sqrt(final);
-    }
+    if (threadIdx.x == 0)
+      *ret = sqrt(final);
+  }
 
-    if (per_tensor)
-    {
-        float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;
+  if (per_tensor) {
+    float *output_this_tensor =
+        output_per_tensor + blockIdx.x * max_chunks_per_tensor;
 
-        float val = 0;
-        for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
-            val += output_this_tensor[i];
+    float val = 0;
+    for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
+      val += output_this_tensor[i];
 
-        float final = reduce_block_into_lanes(vals, val);
+    float final = reduce_block_into_lanes(vals, val);
 
-        if (threadIdx.x == 0)
-            ret_per_tensor[blockIdx.x] = sqrt(final);
-    }
+    if (threadIdx.x == 0)
+      ret_per_tensor[blockIdx.x] = sqrt(final);
+  }
 }
 
-__global__ void cleanup_v2(
-    float *output,
-    float *output_per_tensor,
-    float *ret,
-    float *ret_per_tensor,
-    bool per_tensor,
-    int max_chunks_per_tensor,
-    int norm_type,
-    float alpha,
-    float beta)
-{
-    __shared__ float vals[512];
+__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
+                           float *ret_per_tensor, bool per_tensor,
+                           int max_chunks_per_tensor, int norm_type,
+                           float alpha, float beta) {
+  __shared__ float vals[512];
 
-    if (blockIdx.x == 0)
-    {
-        float val = 0;
-        if (threadIdx.x < 320)
-            val = output[threadIdx.x];
+  if (blockIdx.x == 0) {
+    float val = 0;
+    if (threadIdx.x < 320)
+      val = output[threadIdx.x];
 
-        if (norm_type == 0)
-        {
-            float final = reduce_block_into_lanes_max_op(vals, val);
-            if (threadIdx.x == 0)
-                *ret = alpha * (*ret) + beta * final;
-        }
-        else
-        {
-            float final = reduce_block_into_lanes(vals, val);
-            if (threadIdx.x == 0)
-                *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
-        }
-    }
+    if (norm_type == 0) {
+      float final = reduce_block_into_lanes_max_op(vals, val);
+      if (threadIdx.x == 0)
+        *ret = alpha * (*ret) + beta * final;
+    } else {
+      float final = reduce_block_into_lanes(vals, val);
+      if (threadIdx.x == 0)
+        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
+    }
+  }
 
-    if (per_tensor)
-    {
-        float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;
+  if (per_tensor) {
+    float *output_this_tensor =
+        output_per_tensor + blockIdx.x * max_chunks_per_tensor;
 
-        if (norm_type == 0)
-        {
-            float val = 0;
-            for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
-                val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
+    if (norm_type == 0) {
+      float val = 0;
+      for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
+        val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
 
-            float final = reduce_block_into_lanes_max_op(vals, val);
+      float final = reduce_block_into_lanes_max_op(vals, val);
 
-            if (threadIdx.x == 0)
-                ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
-        }
-        else
-        {
-            float val = 0;
-            for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
-                val += output_this_tensor[i];
+      if (threadIdx.x == 0)
+        ret_per_tensor[blockIdx.x] =
+            alpha * ret_per_tensor[blockIdx.x] + beta * final;
+    } else {
+      float val = 0;
+      for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
+        val += output_this_tensor[i];
 
-            float final = reduce_block_into_lanes(vals, val);
+      float final = reduce_block_into_lanes(vals, val);
 
-            if (threadIdx.x == 0)
-                ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
-        }
-    }
+      if (threadIdx.x == 0)
+        ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] *
+                                              ret_per_tensor[blockIdx.x] +
+                                          beta * final);
+    }
+  }
 }
 
-std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
-    int chunk_size,
-    at::Tensor noop_flag,
-    std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::optional<bool> per_tensor_python)
-{
-    bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
+std::tuple<at::Tensor, at::Tensor>
+multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
+                         std::vector<std::vector<at::Tensor>> tensor_lists,
+                         at::optional<bool> per_tensor_python) {
+  bool per_tensor =
+      per_tensor_python.has_value() ? per_tensor_python.value() : false;
 
-    auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-    auto output = at::zeros({320}, float_options);
+  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  auto output = at::zeros({320}, float_options);
 
-    at::Tensor output_per_tensor;
-    at::Tensor ret_per_tensor;
+  at::Tensor output_per_tensor;
+  at::Tensor ret_per_tensor;
 
-    int ntensors = tensor_lists[0].size();
-    int max_chunks_per_tensor = -1;
+  int ntensors = tensor_lists[0].size();
+  int max_chunks_per_tensor = -1;
 
-    if (per_tensor)
-    {
-        for (int t = 0; t < ntensors; t++)
-        {
-            int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-            if (max_chunks_this_tensor > max_chunks_per_tensor)
-                max_chunks_per_tensor = max_chunks_this_tensor;
-        }
-        output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
-        ret_per_tensor = at::empty({ntensors}, float_options);
-    }
-    else
-    {
-        ret_per_tensor = at::empty({0}, float_options);
-    }
+  if (per_tensor) {
+    for (int t = 0; t < ntensors; t++) {
+      int max_chunks_this_tensor =
+          (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+      if (max_chunks_this_tensor > max_chunks_per_tensor)
+        max_chunks_per_tensor = max_chunks_this_tensor;
+    }
+    output_per_tensor =
+        at::zeros({ntensors * max_chunks_per_tensor}, float_options);
+    ret_per_tensor = at::empty({ntensors}, float_options);
+  } else {
+    ret_per_tensor = at::empty({0}, float_options);
+  }
 
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-                            multi_tensor_apply<1>(
-                                BLOCK_SIZE,
-                                chunk_size,
-                                noop_flag,
-                                tensor_lists,
-                                L2NormFunctor<scalar_t_0>(),
-                                output.DATA_PTR<float>(),
-                                per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
-                                per_tensor,
-                                max_chunks_per_tensor);)
+  DISPATCH_FLOAT_AND_HALF(
+      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+      multi_tensor_apply<1>(
+          BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+          L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
+          per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
+          per_tensor, max_chunks_per_tensor);)
 
-    AT_CUDA_CHECK(cudaGetLastError());
-    // AT_CUDA_CHECK(cudaDeviceSynchronize());
+  AT_CUDA_CHECK(cudaGetLastError());
+  // AT_CUDA_CHECK(cudaDeviceSynchronize());
 
-    // This involves one more small kernel launches, but will be negligible end to end.
-    // I could get rid of these by hacking the functor + multi tensor harness with persistence
-    // logic, but keeping it simple for now
-    auto ret = at::empty({1}, output.options());
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-    auto stream = at::cuda::getCurrentCUDAStream();
-    cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
-        output.DATA_PTR<float>(),
-        per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
-        ret.DATA_PTR<float>(),
-        per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
-        per_tensor,
-        max_chunks_per_tensor);
+  // This involves one more small kernel launches, but will be negligible end to
+  // end. I could get rid of these by hacking the functor + multi tensor harness
+  // with persistence logic, but keeping it simple for now
+  auto ret = at::empty({1}, output.options());
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+  auto stream = at::cuda::getCurrentCUDAStream();
+  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
+      output.DATA_PTR<float>(),
+      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
+      ret.DATA_PTR<float>(),
+      per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr, per_tensor,
+      max_chunks_per_tensor);
 
-    return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
+  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
 
 // Compute and update grad norm
@@ -366,90 +322,66 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
 // L-2: gn = sqrt(a * gn^2 + b * n^2)
 // L-inf: gn = a * gn + b * n
 void multi_tensor_norm_out_cuda(
-    int chunk_size,
-    at::Tensor noop_flag,
-    std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::Tensor out,
-    const float alpha,
-    const float beta,
-    const int norm_type)
-{
-    auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-    TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
-    // we don't need global thus uses empty here
-    auto output = at::empty({320}, float_options);
+    int chunk_size, at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor out,
+    const float alpha, const float beta, const int norm_type) {
+  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
+  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(),
+              "noop flag should be on the same device as tensors");
+  // we don't need global thus uses empty here
+  auto output = at::empty({320}, float_options);
 
-    at::Tensor output_per_tensor;
-    at::Tensor ret_per_tensor;
+  at::Tensor output_per_tensor;
+  at::Tensor ret_per_tensor;
 
-    int ntensors = tensor_lists[0].size();
-    int max_chunks_per_tensor = -1;
+  int ntensors = tensor_lists[0].size();
+  int max_chunks_per_tensor = -1;
 
-    for (int t = 0; t < ntensors; t++)
-    {
-        int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-        if (max_chunks_this_tensor > max_chunks_per_tensor)
-            max_chunks_per_tensor = max_chunks_this_tensor;
-    }
+  for (int t = 0; t < ntensors; t++) {
+    int max_chunks_this_tensor =
+        (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+    if (max_chunks_this_tensor > max_chunks_per_tensor)
+      max_chunks_per_tensor = max_chunks_this_tensor;
+  }
 
-    // Although it is single write then read, still need to be zero
-    // Since tailing element also participate cleanup
-    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
+  // Although it is single write then read, still need to be zero
+  // Since tailing element also participate cleanup
+  output_per_tensor =
+      at::zeros({ntensors * max_chunks_per_tensor}, float_options);
 
-    if (norm_type == 0)
-    {
-        DISPATCH_FLOAT_AND_HALF(
-            tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
-            multi_tensor_apply<1>(
-                BLOCK_SIZE,
-                chunk_size,
-                noop_flag,
-                tensor_lists,
-                MaxNormFunctor<scalar_t_0>(),
-                output.DATA_PTR<float>(),
-                output_per_tensor.DATA_PTR<float>(),
-                true,
-                max_chunks_per_tensor);)
-    }
-    else
-    {
-        DISPATCH_FLOAT_AND_HALF(
-            tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-            multi_tensor_apply<1>(
-                BLOCK_SIZE,
-                chunk_size,
-                noop_flag,
-                tensor_lists,
-                L2NormFunctor<scalar_t_0>(),
-                output.DATA_PTR<float>(),
-                output_per_tensor.DATA_PTR<float>(),
-                true,
-                max_chunks_per_tensor);)
-    }
-    AT_CUDA_CHECK(cudaGetLastError());
+  if (norm_type == 0) {
+    DISPATCH_FLOAT_AND_HALF(
+        tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
+        multi_tensor_apply<1>(
+            BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+            MaxNormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
+            output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
+  } else {
+    DISPATCH_FLOAT_AND_HALF(
+        tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+        multi_tensor_apply<1>(
+            BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+            L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
+            output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
+  }
+  AT_CUDA_CHECK(cudaGetLastError());
 
-    // AT_CUDA_CHECK(cudaDeviceSynchronize());
+  // AT_CUDA_CHECK(cudaDeviceSynchronize());
 
-    // This involves one more small kernel launches, but will be negligible end to end.
-    // I could get rid of these by hacking the functor + multi tensor harness with persistence
-    // logic, but keeping it simple for now
-    auto ret = at::empty({1}, output.options());
+  // This involves one more small kernel launches, but will be negligible end to
+  // end. I could get rid of these by hacking the functor + multi tensor harness
+  // with persistence logic, but keeping it simple for now
+  auto ret = at::empty({1}, output.options());
 
-    // Adding the following device guard since it happens sometimes that the
-    // tensors are on one device and the cuda stream is on another device which
-    // results in ILLEGAL MEM ACCESS error.
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-    auto stream = at::cuda::getCurrentCUDAStream();
-    cleanup_v2<<<ntensors, 512, 0, stream>>>(
-        output.DATA_PTR<float>(),
-        output_per_tensor.DATA_PTR<float>(),
-        ret.DATA_PTR<float>(),
-        out.DATA_PTR<float>(),
-        true,
-        max_chunks_per_tensor,
-        norm_type,
-        alpha,
-        beta);
+  // Adding the following device guard since it happens sometimes that the
+  // tensors are on one device and the cuda stream is on another device which
+  // results in ILLEGAL MEM ACCESS error.
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
+  auto stream = at::cuda::getCurrentCUDAStream();
+  cleanup_v2<<<ntensors, 512, 0, stream>>>(
+      output.DATA_PTR<float>(), output_per_tensor.DATA_PTR<float>(),
+      ret.DATA_PTR<float>(), out.DATA_PTR<float>(), true, max_chunks_per_tensor,
+      norm_type, alpha, beta);
 
-    return;
+  return;
 }
\ No newline at end of file
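
For reference, the gradient-norm blending rule documented in the comment above multi_tensor_norm_out_cuda ("L-2: gn = sqrt(a * gn^2 + b * n^2)", "L-inf: gn = a * gn + b * n") can be sketched on the host side as follows. This is an illustrative sketch only, not part of the patch; the helper name blend_norm and the sample values are hypothetical.

// Illustrative sketch (not part of the patch): how an old grad norm `gn` is
// blended with a new per-tensor norm `n`, following the comment above.
// The helper name `blend_norm` and the sample numbers are hypothetical.
#include <cmath>
#include <cstdio>

// norm_type == 0 selects the L-inf rule, anything else the L-2 rule,
// mirroring the `norm_type` switch in cleanup_v2.
float blend_norm(int norm_type, float alpha, float beta, float gn, float n) {
  if (norm_type == 0)
    return alpha * gn + beta * n; // L-inf: gn = a * gn + b * n
  return std::sqrt(alpha * gn * gn + beta * n * n); // L-2: gn = sqrt(a*gn^2 + b*n^2)
}

int main() {
  // With alpha = 0 and beta = 1 the blend simply replaces the old norm.
  std::printf("L-2  : %f\n", blend_norm(1, 0.f, 1.f, 3.f, 4.f)); // prints 4.0
  std::printf("L-inf: %f\n", blend_norm(0, 0.f, 1.f, 3.f, 4.f)); // prints 4.0
  return 0;
}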