Recover kernel files

binmakeswell 2022-07-13 11:57:42 +08:00 committed by Frank Lee
parent e83b2ce853
commit 7696cead8d
8 changed files with 1217 additions and 1303 deletions


@@ -1,10 +1,11 @@
#include <chrono>
#include <ctime>

#include "kernels.h"

#include <cooperative_groups.h>

namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate;


@@ -3,11 +3,10 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdexcept>

#define MAX_THREADS 1024
#define WARP_SIZE 32

@@ -133,9 +132,8 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
}

/* Convert 4-dim tensor index into vector index */
__forceinline__ __host__ __device__ int
flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) {
  // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
  int res = id4;

@@ -203,9 +201,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
}

/* Convert vector index to 6-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
               int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) {
  *id5 = src % dim5;
  src /= dim5;

@@ -223,11 +221,9 @@ __forceinline__ __host__ __device__ void decompose_6dim(
}

/* Convert vector index to 5-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0,
               int *id1, int *id2, int *id3, int *id4) {
  *id4 = src % dim4;
  src /= dim4;

@@ -257,9 +253,8 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
}

/* Convert vector index to 3-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) {
  *id2 = src % dim2;
  src /= dim2;
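
Aside (not part of the commit): a minimal host-only sketch of the indexing scheme the flat_*dim / decompose_*dim helpers above implement, assuming the Horner-style flattening shown in the commented-out formula; the concrete dimensions below are illustrative only.

// Sketch: flatten a 4-dim index with Horner's rule and invert the trailing
// components with repeated modulo/divide, mirroring flat_4dim / decompose_3dim.
#include <cassert>
#include <cstdio>

int flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) {
  return ((id1 * dim2 + id2) * dim3 + id3) * dim4 + id4;
}

void decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) {
  *id2 = src % dim2;
  src /= dim2;
  *id1 = src % dim1;
  *id0 = src / dim1;
}

int main() {
  // (1, 2, 3, 4) in a tensor with trailing dims (5, 6, 7) flattens to 319.
  int flat = flat_4dim(1, 2, 3, 4, 5, 6, 7);
  assert(flat == 1 * (5 * 6 * 7) + 2 * (6 * 7) + 3 * 7 + 4);  // 319
  // Dropping the leading dimension, 319 % (5*6*7) = 109 decomposes back to (2, 3, 4).
  int id0, id1, id2;
  decompose_3dim(flat % (5 * 6 * 7), /*dim1=*/6, /*dim2=*/7, &id0, &id1, &id2);
  printf("%d %d %d\n", id0, id1, id2);  // prints: 2 3 4
  return 0;
}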


@@ -135,10 +135,9 @@ __global__ void bias_add_transform_20314(T *output, const T *input,
                                         const T *bias, int dim_3, int dim_4);

template <>
__global__ void
bias_add_transform_20314<float>(float *output, const float *input,
                                const float *bias, int dim_3, int dim_4) {
  int id0 = blockIdx.x;
  int id1 = blockIdx.y;
  int id2 = blockIdx.z;

@@ -174,10 +173,9 @@ __global__ void bias_add_transform_20314<float>(float *output,
}

template <>
__global__ void
bias_add_transform_20314<__half>(__half *output, const __half *input,
                                 const __half *bias, int dim_3, int dim_4) {
  int id0 = blockIdx.x;
  int id1 = blockIdx.y;
  int id2 = blockIdx.z;


@@ -1,14 +1,13 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"

#include <assert.h>

// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

@@ -18,52 +17,54 @@ constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template <int n>
struct TensorListMetadata
{
  void *addresses[n][depth_to_max_tensors[n - 1]];
  int sizes[depth_to_max_tensors[n - 1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
  int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
  int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
    int chunk_size,
    volatile int *noop_flag,
    T tl,
    U callable,
    ArgTypes... args)
{
  // Hand the chunk information to the user-supplied functor to process however it likes.
  callable(chunk_size, noop_flag, tl, args...);
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
    int block_size,
    int chunk_size,
    const at::Tensor &noop_flag,
    const std::vector<std::vector<at::Tensor>> &tensor_lists,
    T callable,
    ArgTypes... args)
{
  TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
  int len0 = tensor_lists[0].size();
  TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
  auto ref_device = tensor_lists[0][0].device();
  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
  for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
    TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
    for (int t = 0; t < tensor_lists[l].size(); t++)
    {
      // TODO: Print which tensor fails.
      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
      contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
      TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
      TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
      TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
    }
  }

@@ -77,16 +78,17 @@ void multi_tensor_apply(
  tl.start_tensor_this_launch = 0;
  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (int t = 0; t < ntensors; t++)
  {
    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
    for (int d = 0; d < depth; d++)
      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
    loc_tensor_info++;

    int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

    for (int chunk = 0; chunk < chunks_this_tensor; chunk++)
    {
      // std::cout << chunks_this_tensor << std::endl;
      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tl.block_to_chunk[loc_block_info] = chunk;

@@ -96,23 +98,29 @@ void multi_tensor_apply(
                           chunk == chunks_this_tensor - 1);
      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
      if (tensors_full || blocks_full || last_chunk)
      {
        // using accscalar_t = acc_type<scalar_t, true>;
        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
            chunk_size,
            noop_flag.DATA_PTR<int>(),
            tl,
            callable,
            args...);

        AT_CUDA_CHECK(cudaGetLastError());

        // Reset. The control flow possibilities here make my brain hurt.
        loc_block_info = 0;
        if (chunk == chunks_this_tensor - 1)
        {
          // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          loc_tensor_info = 0;
          tl.start_tensor_this_launch = t + 1;
        }
        else
        {
          // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
          for (int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
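
Aside (not part of the commit): an illustrative host-only sketch of the chunking bookkeeping that multi_tensor_apply performs above; the sizes and limits are made up, and the real code additionally tracks per-launch tensor slots and device pointers.

// Sketch: cut each tensor into fixed-size chunks, map one chunk to one block,
// and flush a "launch" whenever the block budget fills up or the last chunk is seen.
#include <cstdio>
#include <vector>

int main() {
  const int chunk_size = 4;              // elements per chunk (hypothetical)
  const int max_blocks = 5;              // stand-in for depth_to_max_blocks[depth - 1]
  std::vector<int> numels = {10, 3, 9};  // hypothetical tensor sizes

  int blocks_in_launch = 0;
  for (size_t t = 0; t < numels.size(); t++) {
    int chunks_this_tensor = (numels[t] + chunk_size - 1) / chunk_size;  // ceiling division
    for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
      printf("block -> (tensor %zu, chunk %d)\n", t, chunk);
      blocks_in_launch++;
      bool blocks_full = (blocks_in_launch == max_blocks);
      bool last_chunk = (t == numels.size() - 1 && chunk == chunks_this_tensor - 1);
      if (blocks_full || last_chunk) {
        // In the real code this is where multi_tensor_apply_kernel is launched.
        printf("flush launch with %d block(s)\n", blocks_in_launch);
        blocks_in_launch = 0;
      }
    }
  }
  return 0;
}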


@@ -3,24 +3,31 @@
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||

@@ -31,8 +38,11 @@ torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

@@ -46,10 +56,12 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax

@@ -57,14 +69,16 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size."
  );
}


@@ -4,12 +4,12 @@
#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

@@ -17,40 +17,22 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;

@@ -58,12 +40,14 @@ int log2_ceil(int value) {
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;

@@ -71,9 +55,8 @@ struct Max {
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else

@@ -81,13 +64,12 @@ WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);

@@ -96,35 +78,34 @@ __device__ __forceinline__ void warp_reduce(acc_t *sum) {
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

@@ -132,10 +113,10 @@ __global__ void scaled_masked_softmax_warp_forward(
    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

@@ -146,20 +127,20 @@ __global__ void scaled_masked_softmax_warp_forward(
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;

@@ -168,7 +149,7 @@ __global__ void scaled_masked_softmax_warp_forward(
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }

@@ -178,21 +159,20 @@ __global__ void scaled_masked_softmax_warp_forward(
    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];

@@ -202,19 +182,19 @@ __global__ void scaled_masked_softmax_warp_forward(
    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }

@@ -222,16 +202,19 @@ __global__ void scaled_masked_softmax_warp_forward(
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

@@ -243,92 +226,84 @@ __global__ void scaled_masked_softmax_warp_backward(
    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;

@@ -338,14 +313,19 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
    return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
        return;
    } else {

@@ -353,13 +333,10 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization

@@ -367,70 +344,58 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;

@@ -438,14 +403,18 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
        return;
    } else {

@@ -453,13 +422,10 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization

@@ -467,69 +433,57 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count/batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            default:
                break;
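
Aside (not part of the commit): a standalone sketch of the butterfly reduction that warp_reduce above builds from WARP_SHFL_XOR_NATIVE; XOR-shuffling with offsets 16, 8, 4, 2, 1 leaves every lane of a 32-thread warp holding the full sum. The kernel and names below are illustrative only.

// Sketch: single-warp sum via __shfl_xor_sync, the same primitive warp_reduce uses.
#include <cstdio>

__global__ void warp_sum_demo(float *out) {
  float val = static_cast<float>(threadIdx.x);  // lane i contributes the value i
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_xor_sync(0xffffffff, val, offset, 32);
  }
  if (threadIdx.x == 0) {
    *out = val;  // every lane now holds 0 + 1 + ... + 31 = 496
  }
}

int main() {
  float *d_out = nullptr, h_out = 0.0f;
  cudaMalloc(&d_out, sizeof(float));
  warp_sum_demo<<<1, 32>>>(d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %.0f\n", h_out);  // expected: 496
  cudaFree(d_out);
  return 0;
}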


@ -4,12 +4,11 @@
#pragma once #pragma once
#include <assert.h> #include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat> #include <cfloat>
#include <limits> #include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace { namespace {
@ -17,65 +16,38 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <> template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>( __device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <> template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>( __device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <> template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, __device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
const c10::Half *src) {
*dst = *src;
}
template <> template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, __device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <> template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
const uint8_t *src) {
*dst = *src;
}
template <> template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
const uint8_t *src) {
*((half2 *)dst) = *((half2 *)src);
}
template <typename Datatype, int ELEMENTS_PER_LDG> template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst); __device__ __inline__ void copy_zero_vector(Datatype *dst);
template <> template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>( __device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
c10::BFloat16 *dst) {
*dst = 0.0;
}
template <> template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>( __device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
c10::BFloat16 *dst) {
*((float2 *)dst) = make_float2(0.0f, 0.0f);
}
template <> template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { __device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
*dst = 0.0;
}
template <> template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { __device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
*((float2 *)dst) = make_float2(0.0f, 0.0f);
}
int log2_ceil(int value) { int log2_ceil(int value) {
int log2_value = 0; int log2_value = 0;
@ -83,12 +55,14 @@ int log2_ceil(int value) {
return log2_value; return log2_value;
} }
template <typename T> template<typename T>
struct Add { struct Add {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; } __device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
}; };
template <typename T> template<typename T>
struct Max { struct Max {
__device__ __forceinline__ T operator()(T a, T b) const { __device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a; return a < b ? b : a;
@ -96,9 +70,8 @@ struct Max {
}; };
template <typename T> template <typename T>
__device__ __forceinline__ T __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, {
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width); return __shfl_xor_sync(mask, value, laneMask, width);
#else #else
@ -106,13 +79,12 @@ WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
#endif #endif
} }
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
template <typename> class ReduceOp> __device__ __forceinline__ void warp_reduce(acc_t* sum) {
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
ReduceOp<acc_t> r; ReduceOp<acc_t> r;
#pragma unroll #pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b); sum[i] = r(sum[i], b);
@ -121,37 +93,38 @@ __device__ __forceinline__ void warp_reduce(acc_t *sum) {
} }
/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_forward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
  int local_seq = blockIdx.x + 1;
  int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
@ -160,28 +133,27 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
  // load data from global memory
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : local_seq;

#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i * element_count * stride + it * WARP_SIZE);

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if ((element_index + element) < batch_element_count) {
            elements[i][it + element] = (acc_t)temp_data[element] * scale;
          } else {
            elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
          }
        }
      } else {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
        }
@ -191,21 +163,20 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
  // compute max_value
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

  acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (it < warp_iteration_limit) {
        elements[i][it] = std::exp((elements[i][it] - max_value[i]));
@ -217,15 +188,17 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
  // store result
  output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < local_seq) {

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < local_seq) {
            out[element] = elements[i][it + element] / sum[i];
@ -233,11 +206,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
            out[element] = 0;
          }
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
      } else if (element_index < element_count) {
        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
      } else {
        break;
      }
@ -245,32 +216,34 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
  }
}
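
// Editorial reference (not in the original source): per row r of a [seq_len, seq_len]
// attention score block, the kernel above computes a scaled softmax over the first
// r+1 columns and writes zeros to the masked upper triangle. A minimal single-row
// CPU sketch of the same math (assuming plain float buffers; <cmath> and <limits>
// are already relied upon by the kernel code above):
inline void upper_triang_softmax_row_reference(const float *x, float *y, int row,
                                               int seq_len, float scale) {
  const int valid = row + 1;  // diagonal masking: columns > row are dropped
  float m = -std::numeric_limits<float>::infinity();
  for (int j = 0; j < valid; ++j) {
    const float v = scale * x[j];
    if (v > m) m = v;
  }
  float sum = 0.0f;
  for (int j = 0; j < valid; ++j) sum += std::exp(scale * x[j] - m);
  for (int j = 0; j < seq_len; ++j)
    y[j] = (j < valid) ? std::exp(scale * x[j] - m) / sum : 0.0f;
}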
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_backward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
  int local_seq = blockIdx.x + 1;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH)
    local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
@ -280,34 +253,31 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
  gradInput += thread_offset;

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
  input_t temp_grad[ELEMENTS_PER_LDG_STG];
  input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : local_seq;

#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < batch_element_count) {
            output_reg[i][it + element] = (acc_t)temp_output[element];
          }
        }
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < batch_element_count) {
            grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
          }
        }
      }
@ -315,34 +285,32 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
  }

  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches)
      break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
      }
    }
  }
@ -350,11 +318,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
} // end of anonymous namespace
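
// Editorial note (not in the original source): with y the softmax output and dy the
// incoming gradient, the backward kernel above evaluates the softmax Jacobian-vector
// product, folding in the forward scale:
//
//   dx[j] = scale * y[j] * ( dy[j] - sum_k dy[k] * y[k] )
//
// which is exactly grad_reg = dy * y, sum = warp-reduced add of grad_reg, and
// out = scale * (grad_reg - y * sum) in the code.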
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
  if (softmax_elements == 0) {
    return;
  } else {
@ -363,13 +336,10 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximimize gpu utilization
@ -385,88 +355,52 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0: // 1
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 1: // 2
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 2: // 4
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 3: // 8
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 4: // 16
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 5: // 32
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 6: // 64
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 7: // 128
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 8: // 256
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 9: // 512
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 10: // 1024
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 11: // 2048
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      default:
        break;
@ -474,12 +408,17 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
  }
}
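
// Illustrative sketch (not part of the original file): how a caller might invoke the
// forward dispatcher. `d_out`, `d_in`, `attn_batches` and `seq_len` are assumptions
// made only for this example; the real callers live in the C++ binding layer and pass
// device pointers taken from contiguous [attn_batches, seq_len, seq_len] tensors.
inline void example_forward_launch(c10::Half *d_out, const c10::Half *d_in,
                                   float scale, int attn_batches, int seq_len) {
  // softmax_elements and its stride are both the sequence length for the contiguous
  // layout assumed above; accumulation happens in float.
  dispatch_scaled_upper_triang_masked_softmax_forward<c10::Half, c10::Half, float>(
      d_out, d_in, scale, seq_len, seq_len, attn_batches);
}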
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
  if (softmax_elements == 0) {
    return;
  } else {
@ -488,13 +427,10 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximimize gpu utilization
@ -510,88 +446,52 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0: // 1
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 1: // 2
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 2: // 4
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 3: // 8
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 4: // 16
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 5: // 32
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 6: // 64
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 7: // 128
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 8: // 256
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 9: // 512
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 10: // 1024
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      case 11: // 2048
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
        break;
      default:
        break;

View File

@ -1,15 +1,18 @@
#include <ATen/ATen.h>
#include "compat.h"

#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Half: \
    { \
      using scalar_t = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t = at::BFloat16; \
      __VA_ARGS__; \
      break; \
@ -18,22 +21,30 @@
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
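
// Illustrative sketch (not part of the original file): the dispatch macros bind
// `scalar_t` (or the scalar_t_in / scalar_t_out / scalar_t_##LEVEL variants below)
// inside the __VA_ARGS__ body so callers stay type-generic. `sum_as_double` and
// `dispatch_example` are assumed names, and the loop assumes a contiguous CPU tensor.
template <typename scalar_t>
double sum_as_double(const at::Tensor &t) {
  double acc = 0.0;
  const scalar_t *p = t.data_ptr<scalar_t>();
  for (int64_t i = 0; i < t.numel(); ++i) acc += static_cast<float>(p[i]);
  return acc;
}

inline double dispatch_example(const at::Tensor &t) {
  double result = 0.0;
  DISPATCH_HALF_AND_BFLOAT(t.scalar_type(), "dispatch_example",
                           result = sum_as_double<scalar_t>(t));
  return result;
}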
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_in = float; \
      switch(TYPEOUT) \
      { \
        case at::ScalarType::Float: \
        { \
          using scalar_t_out = float; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::Half: \
        { \
          using scalar_t_out = at::Half; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
          using scalar_t_out = at::BFloat16; \
          __VA_ARGS__; \
          break; \
@ -43,13 +54,15 @@
      } \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_in = at::Half; \
      using scalar_t_out = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_in = at::BFloat16; \
      using scalar_t_out = at::BFloat16; \
      __VA_ARGS__; \
@ -73,13 +86,16 @@
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
  switch (TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
@ -89,18 +105,22 @@
  }
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
  switch (TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Byte: \
    { \
      using scalar_t_##LEVEL = uint8_t; \
      __VA_ARGS__; \
      break; \
@ -110,18 +130,22 @@
  }
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
  switch (TYPE) \
  { \
    case at::ScalarType::Double: \
    { \
      using scalar_t_##LEVEL = double; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
@ -131,13 +155,16 @@
  }
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
  switch (TYPE) \
  { \
    case at::ScalarType::Double: \
    { \
      using scalar_t_##LEVEL = double; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
@ -147,52 +174,62 @@
  }
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
  { \
    using g_scalar_t_##LEVEL = float; \
    using p_scalar_t_##LEVEL = float; \
    __VA_ARGS__; \
  } \
  else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
  { \
    using g_scalar_t_##LEVEL = float; \
    using p_scalar_t_##LEVEL = at::Half; \
    __VA_ARGS__; \
  } \
  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
  { \
    using g_scalar_t_##LEVEL = at::Half; \
    using p_scalar_t_##LEVEL = float; \
    __VA_ARGS__; \
  } \
  else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
  { \
    using g_scalar_t_##LEVEL = at::Half; \
    using p_scalar_t_##LEVEL = at::Half; \
    __VA_ARGS__; \
  } \
  else \
  { \
    AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
  } \
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false) // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if (tid < i)
      x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  if (tid < 32)
  {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
@ -204,8 +241,10 @@ __device__ __forceinline__ T reduce_block_into_lanes(
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if (share_result)
  {
    if (tid < lanes)
      x[tid] = final; // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }
@ -214,28 +253,32 @@ __device__ __forceinline__ T reduce_block_into_lanes(
}
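
// Illustrative sketch (not part of the original file): a block-wide sum built on
// reduce_block_into_lanes. The shared buffer must hold blockDim.x * blockDim.y
// elements and the block size must be a multiple of 32; `block_sum_example` and a
// launch such as <<<1, 128, 128 * sizeof(float)>>> are assumptions for this example.
__global__ void block_sum_example(const float *in, float *out, int n) {
  extern __shared__ float smem[];
  const int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const float val = (tid < n) ? in[tid] : 0.0f;
  // With the default lanes = 1, thread 0 of the block ends up holding the full sum.
  const float total = reduce_block_into_lanes(smem, val);
  if (tid == 0) out[0] = total;
}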
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
                                                            T val,
                                                            int lanes = 1,
                                                            bool share_result = false) // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64)
  {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1)
  {
    if (tid < i)
      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  if (tid < 32)
  {
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
@ -244,12 +287,13 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op(
#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if (share_result)
  {
    if (tid < lanes)
      x[tid] = final; // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }