mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 21:22:04 +00:00
[NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h code style (#1263)
This commit is contained in:
parent
197a2c89e2
commit
5d7366b144
@ -4,12 +4,12 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <stdint.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
#include <c10/macros/Macros.h>
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -17,22 +17,40 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
|
|||||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
|
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
|
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
|
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
|
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
|
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
|
||||||
|
const uint8_t *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
|
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
|
||||||
|
const uint8_t *src) {
|
||||||
|
*((half2 *)dst) = *((half2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
int log2_ceil(int value) {
|
int log2_ceil(int value) {
|
||||||
int log2_value = 0;
|
int log2_value = 0;
|
||||||
@ -42,9 +60,7 @@ int log2_ceil(int value) {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Add {
|
struct Add {
|
||||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
||||||
return a + b;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -55,8 +71,9 @@ struct Max {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
__device__ __forceinline__ T
|
||||||
{
|
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
||||||
|
unsigned int mask = 0xffffffff) {
|
||||||
#if CUDA_VERSION >= 9000
|
#if CUDA_VERSION >= 9000
|
||||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||||
#else
|
#else
|
||||||
@ -64,7 +81,8 @@ __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int wid
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
|
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
||||||
|
template <typename> class ReduceOp>
|
||||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
||||||
ReduceOp<acc_t> r;
|
ReduceOp<acc_t> r;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -78,34 +96,35 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Extended softmax (from native aten pytorch) with following additional features
|
* Extended softmax (from native aten pytorch) with following additional
|
||||||
* 1) input scaling
|
* features 1) input scaling 2) Explicit masking
|
||||||
* 2) Explicit masking
|
|
||||||
*/
|
*/
|
||||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
template <typename input_t, typename output_t, typename acc_t,
|
||||||
|
int log2_elements>
|
||||||
__global__ void scaled_masked_softmax_warp_forward(
|
__global__ void scaled_masked_softmax_warp_forward(
|
||||||
output_t *dst,
|
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
|
||||||
const input_t *src,
|
int micro_batch_size, int element_count, int pad_batches) {
|
||||||
const uint8_t *mask,
|
|
||||||
const acc_t scale,
|
|
||||||
int micro_batch_size,
|
|
||||||
int element_count,
|
|
||||||
int pad_batches)
|
|
||||||
{
|
|
||||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
// warp_size of method warp_softmax_forward_kernel.
|
// warp_size of method warp_softmax_forward_kernel.
|
||||||
constexpr int next_power_of_two = 1 << log2_elements;
|
constexpr int next_power_of_two = 1 << log2_elements;
|
||||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
constexpr int WARP_SIZE =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
|
||||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||||
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
|
int first_batch =
|
||||||
|
(blockDim.y *
|
||||||
|
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
|
||||||
|
threadIdx.y) *
|
||||||
|
WARP_BATCH;
|
||||||
int pad_first_batch = 0;
|
int pad_first_batch = 0;
|
||||||
if (pad_batches != 1) { // bert style
|
if (pad_batches != 1) { // bert style
|
||||||
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
|
pad_first_batch =
|
||||||
|
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
|
||||||
|
WARP_BATCH;
|
||||||
} else { // gpt2 style
|
} else { // gpt2 style
|
||||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||||
}
|
}
|
||||||
@ -113,10 +132,10 @@ __global__ void scaled_masked_softmax_warp_forward(
|
|||||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||||
// many batches have to computed within this WARP.
|
// many batches have to computed within this WARP.
|
||||||
int local_batches = micro_batch_size - first_batch;
|
int local_batches = micro_batch_size - first_batch;
|
||||||
if (local_batches > WARP_BATCH)
|
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||||
local_batches = WARP_BATCH;
|
|
||||||
|
|
||||||
// there might be multiple batches per warp. compute the index within the batch
|
// there might be multiple batches per warp. compute the index within the
|
||||||
|
// batch
|
||||||
int local_idx = threadIdx.x;
|
int local_idx = threadIdx.x;
|
||||||
|
|
||||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
@ -164,7 +183,8 @@ __global__ void scaled_masked_softmax_warp_forward(
|
|||||||
max_value[i] = elements[i][0];
|
max_value[i] = elements[i][0];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||||
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
max_value[i] =
|
||||||
|
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||||
@ -184,8 +204,7 @@ __global__ void scaled_masked_softmax_warp_forward(
|
|||||||
output_t out[ELEMENTS_PER_LDG_STG];
|
output_t out[ELEMENTS_PER_LDG_STG];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
if (i >= local_batches)
|
if (i >= local_batches) break;
|
||||||
break;
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
@ -194,7 +213,8 @@ __global__ void scaled_masked_softmax_warp_forward(
|
|||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
out[element] = elements[i][it + element] / sum[i];
|
out[element] = elements[i][it + element] / sum[i];
|
||||||
}
|
}
|
||||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
|
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
dst + i * element_count + it * WARP_SIZE, out);
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -202,19 +222,16 @@ __global__ void scaled_masked_softmax_warp_forward(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
|
template <typename input_t, typename output_t, typename acc_t,
|
||||||
|
int log2_elements>
|
||||||
__global__ void scaled_masked_softmax_warp_backward(
|
__global__ void scaled_masked_softmax_warp_backward(
|
||||||
output_t *gradInput,
|
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||||
input_t *grad,
|
int micro_batch_size, int element_count) {
|
||||||
const input_t *output,
|
|
||||||
acc_t scale,
|
|
||||||
int micro_batch_size,
|
|
||||||
int element_count)
|
|
||||||
{
|
|
||||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
// warp_size of method warp_softmax_backward_kernel.
|
// warp_size of method warp_softmax_backward_kernel.
|
||||||
constexpr int next_power_of_two = 1 << log2_elements;
|
constexpr int next_power_of_two = 1 << log2_elements;
|
||||||
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
constexpr int WARP_SIZE =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
@ -226,14 +243,15 @@ __global__ void scaled_masked_softmax_warp_backward(
|
|||||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||||
// many batches have to computed within this WARP.
|
// many batches have to computed within this WARP.
|
||||||
int local_batches = micro_batch_size - first_batch;
|
int local_batches = micro_batch_size - first_batch;
|
||||||
if (local_batches > WARP_BATCH)
|
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||||
local_batches = WARP_BATCH;
|
|
||||||
|
|
||||||
// there might be multiple batches per warp. compute the index within the batch
|
// there might be multiple batches per warp. compute the index within the
|
||||||
|
// batch
|
||||||
int local_idx = threadIdx.x;
|
int local_idx = threadIdx.x;
|
||||||
|
|
||||||
// the first element to process by the current thread
|
// the first element to process by the current thread
|
||||||
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
int thread_offset =
|
||||||
|
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
grad += thread_offset;
|
grad += thread_offset;
|
||||||
output += thread_offset;
|
output += thread_offset;
|
||||||
gradInput += thread_offset;
|
gradInput += thread_offset;
|
||||||
@ -251,8 +269,10 @@ __global__ void scaled_masked_softmax_warp_backward(
|
|||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
if (element_index < batch_element_count) {
|
if (element_index < batch_element_count) {
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
|
temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
temp_output, output + i * element_count + it * WARP_SIZE);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
@ -260,7 +280,8 @@ __global__ void scaled_masked_softmax_warp_backward(
|
|||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
|
grad_reg[i][it + element] =
|
||||||
|
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -280,8 +301,7 @@ __global__ void scaled_masked_softmax_warp_backward(
|
|||||||
// store result
|
// store result
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
if (i >= local_batches)
|
if (i >= local_batches) break;
|
||||||
break;
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
@ -290,20 +310,25 @@ __global__ void scaled_masked_softmax_warp_backward(
|
|||||||
output_t out[ELEMENTS_PER_LDG_STG];
|
output_t out[ELEMENTS_PER_LDG_STG];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
|
out[element] =
|
||||||
|
(output_t)(scale * (grad_reg[i][it + element] -
|
||||||
|
output_reg[i][it + element] * sum[i]));
|
||||||
}
|
}
|
||||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
|
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
gradInput + i * element_count + it * WARP_SIZE, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // end of anonymous namespace
|
} // end of anonymous namespace
|
||||||
|
|
||||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
|
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||||
|
int attn_heads) {
|
||||||
int log2_elements = log2_ceil(key_seq_len);
|
int log2_elements = log2_ceil(key_seq_len);
|
||||||
const int next_power_of_two = 1 << log2_elements;
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
|
|
||||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
constexpr int threads_per_block = 128;
|
constexpr int threads_per_block = 128;
|
||||||
@ -314,17 +339,12 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t>
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
void dispatch_scaled_masked_softmax_forward(
|
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
|
||||||
output_t *dst,
|
|
||||||
const input_t *src,
|
|
||||||
const uint8_t *mask,
|
const uint8_t *mask,
|
||||||
const input_t scale,
|
const input_t scale,
|
||||||
int query_seq_len,
|
int query_seq_len, int key_seq_len,
|
||||||
int key_seq_len,
|
int batches, int attn_heads,
|
||||||
int batches,
|
int pad_batches) {
|
||||||
int attn_heads,
|
|
||||||
int pad_batches)
|
|
||||||
{
|
|
||||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||||
if (key_seq_len == 0) {
|
if (key_seq_len == 0) {
|
||||||
return;
|
return;
|
||||||
@ -333,10 +353,13 @@ void dispatch_scaled_masked_softmax_forward(
|
|||||||
const int next_power_of_two = 1 << log2_elements;
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
int batch_count = batches * attn_heads * query_seq_len;
|
int batch_count = batches * attn_heads * query_seq_len;
|
||||||
|
|
||||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
|
// This value must match the WARP_SIZE constexpr value computed inside
|
||||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
// softmax_warp_forward.
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
|
||||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
|
// This value must match the WARP_BATCH constexpr value computed inside
|
||||||
|
// softmax_warp_forward.
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
// use 128 threads per block to maximimize gpu utilization
|
// use 128 threads per block to maximimize gpu utilization
|
||||||
@ -351,51 +374,63 @@ void dispatch_scaled_masked_softmax_forward(
|
|||||||
switch (log2_elements) {
|
switch (log2_elements) {
|
||||||
case 0: // 1
|
case 0: // 1
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 1: // 2
|
case 1: // 2
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 2: // 4
|
case 2: // 4
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 3: // 8
|
case 3: // 8
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 4: // 16
|
case 4: // 16
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 5: // 32
|
case 5: // 32
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 6: // 64
|
case 6: // 64
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 7: // 128
|
case 7: // 128
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 8: // 256
|
case 8: // 256
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 9: // 512
|
case 9: // 512
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 10: // 1024
|
case 10: // 1024
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
case 11: // 2048
|
case 11: // 2048
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
@ -404,16 +439,12 @@ void dispatch_scaled_masked_softmax_forward(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t>
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
void dispatch_scaled_masked_softmax_backward(
|
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
|
||||||
output_t *grad_input,
|
|
||||||
input_t *grad,
|
input_t *grad,
|
||||||
const input_t *output,
|
const input_t *output,
|
||||||
const acc_t scale,
|
const acc_t scale,
|
||||||
int query_seq_len,
|
int query_seq_len, int key_seq_len,
|
||||||
int key_seq_len,
|
int batches, int attn_heads) {
|
||||||
int batches,
|
|
||||||
int attn_heads)
|
|
||||||
{
|
|
||||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||||
if (key_seq_len == 0) {
|
if (key_seq_len == 0) {
|
||||||
return;
|
return;
|
||||||
@ -422,10 +453,13 @@ void dispatch_scaled_masked_softmax_backward(
|
|||||||
const int next_power_of_two = 1 << log2_elements;
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
int batch_count = batches * attn_heads * query_seq_len;
|
int batch_count = batches * attn_heads * query_seq_len;
|
||||||
|
|
||||||
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
|
// This value must match the WARP_SIZE constexpr value computed inside
|
||||||
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
// softmax_warp_backward.
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
|
||||||
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
|
// This value must match the WARP_BATCH constexpr value computed inside
|
||||||
|
// softmax_warp_backward.
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
// use 128 threads per block to maximimize gpu utilization
|
// use 128 threads per block to maximimize gpu utilization
|
||||||
@ -439,51 +473,63 @@ void dispatch_scaled_masked_softmax_backward(
|
|||||||
switch (log2_elements) {
|
switch (log2_elements) {
|
||||||
case 0: // 1
|
case 0: // 1
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 1: // 2
|
case 1: // 2
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 2: // 4
|
case 2: // 4
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 3: // 8
|
case 3: // 8
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 4: // 16
|
case 4: // 16
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 5: // 32
|
case 5: // 32
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 6: // 64
|
case 6: // 64
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 7: // 128
|
case 7: // 128
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 8: // 256
|
case 8: // 256
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 9: // 512
|
case 9: // 512
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 10: // 1024
|
case 10: // 1024
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
case 11: // 2048
|
case 11: // 2048
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
|
Loading…
Reference in New Issue
Block a user