mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
refactor csrc (#5582)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vector_copy_utils.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template<typename scalar_t, bool Aligned, int VecSize>
|
template<typename scalar_t, bool Aligned, int VecSize>
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vector_copy_utils.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template<typename scalar_t, bool Aligned, int VecSize>
|
template<typename scalar_t, bool Aligned, int VecSize>
|
||||||
|
@@ -16,8 +16,10 @@ namespace funcs {
|
|||||||
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
|
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
|
||||||
|
|
||||||
// Note(LiuYang): This file provides base math operation for data type
|
// Note(LiuYang): This file provides base math operation for data type
|
||||||
// include POD and cuda built-in type such as half and __nv_bfloat16
|
// include POD and cuda built-in type such as half and __nv_bfloat16.
|
||||||
template <typename LT, typename RT, typename RET, BinaryOpType Op>
|
// Implementation of common and simple binary operators should be placed here,
|
||||||
|
// otherwise, they should be placed in a new file under functors dir.
|
||||||
|
template <typename LT, typename RT, typename RET, BinaryOpType op_type>
|
||||||
struct BinaryOpFunctor;
|
struct BinaryOpFunctor;
|
||||||
|
|
||||||
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
|
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
|
@@ -16,32 +16,6 @@ namespace colossalAI {
|
|||||||
namespace cuda {
|
namespace cuda {
|
||||||
namespace funcs {
|
namespace funcs {
|
||||||
|
|
||||||
// Get type2 from type or vice versa (applied to half and bfloat16)
|
|
||||||
template <typename T>
|
|
||||||
struct TypeConverter {
|
|
||||||
using Type = half2;
|
|
||||||
}; // keep for generality
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<half2> {
|
|
||||||
using Type = at::Half;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<at::Half> {
|
|
||||||
using Type = half2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<__nv_bfloat162> {
|
|
||||||
using Type = at::BFloat16;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TypeConverter<at::BFloat16> {
|
|
||||||
using Type = __nv_bfloat162;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename From, typename To>
|
template <typename From, typename To>
|
||||||
struct CastFunctor : public std::unary_function<From, To> {
|
struct CastFunctor : public std::unary_function<From, To> {
|
||||||
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
|
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
|
||||||
|
46
extensions/csrc/cuda/funcs/unary_functor.h
Normal file
46
extensions/csrc/cuda/funcs/unary_functor.h
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "../utils/micros.h"
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace funcs {
|
||||||
|
|
||||||
|
// Note(LiuYang): As a retrieved table to check which operation is supported
|
||||||
|
// already
|
||||||
|
enum class UnaryOpType { kLog2Ceil = 0 };
|
||||||
|
|
||||||
|
// Note(LiuYang): Implementation of common and simple unary operators should be
|
||||||
|
// placed here, otherwise, they should be placed in a new file under functors
|
||||||
|
// dir.
|
||||||
|
template <typename From, typename To, UnaryOpType op_type>
|
||||||
|
struct UnaryOpFunctor;
|
||||||
|
|
||||||
|
#define COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION( \
|
||||||
|
FROM, TO, UNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
|
||||||
|
template <ARGS> \
|
||||||
|
struct UnaryOpFunctor<FROM, TO, UNARY_OP_TYPE> \
|
||||||
|
: public std::unary_function<FROM, TO> { \
|
||||||
|
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
|
||||||
|
};
|
||||||
|
|
||||||
|
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
|
||||||
|
HOSTDEVICE, {
|
||||||
|
int log2_value = 0;
|
||||||
|
while ((1 << log2_value) < val)
|
||||||
|
++log2_value;
|
||||||
|
return log2_value;
|
||||||
|
})
|
||||||
|
|
||||||
|
#undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION
|
||||||
|
|
||||||
|
} // namespace funcs
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
@@ -2,7 +2,7 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vector_copy_utils.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
#include "../common/mp_type_traits.h"
|
#include "../common/mp_type_traits.h"
|
||||||
|
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vector_copy_utils.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
#include "stdio.h"
|
#include "stdio.h"
|
||||||
|
|
||||||
|
@@ -4,7 +4,7 @@
|
|||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include "../funcs/op_functor.h"
|
#include "../funcs/binary_functor.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
namespace cuda {
|
||||||
@@ -12,7 +12,6 @@ namespace utils {
|
|||||||
|
|
||||||
const float kReduceFloatInfNeg = -100000000.f;
|
const float kReduceFloatInfNeg = -100000000.f;
|
||||||
const float kReduceFloatInfPos = 100000000.f;
|
const float kReduceFloatInfPos = 100000000.f;
|
||||||
const int kWarpSize = 32;
|
|
||||||
const unsigned int kWarpReduceMask = 0xffffffff;
|
const unsigned int kWarpReduceMask = 0xffffffff;
|
||||||
|
|
||||||
enum class ReduceType { kMax = 0, kSum };
|
enum class ReduceType { kMax = 0, kSum };
|
||||||
@@ -31,44 +30,42 @@ struct GetOpForReduceType<T, ReduceType::kSum> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
|
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
|
||||||
for (int offset = 0; offset < LANES; ++offset) { \
|
_Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \
|
||||||
*(VAL_PTR + offset) = \
|
*(VAL_PTR + offset) = \
|
||||||
OP(*(VAL_PTR + offset), \
|
OP(*(VAL_PTR + offset), \
|
||||||
__shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \
|
__shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \
|
#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES) \
|
||||||
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \
|
_Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) { \
|
||||||
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \
|
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
|
||||||
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \
|
}
|
||||||
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \
|
|
||||||
COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES)
|
|
||||||
|
|
||||||
#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \
|
#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \
|
||||||
DEFAULT_VALUE, REDUCE_TYPE) \
|
REDUCE_TYPE) \
|
||||||
__shared__ T shm[LANES][32]; \
|
__shared__ T shm[LANES][32]; \
|
||||||
int lane_id = threadIdx.x & 0x1f; \
|
int lane_id = threadIdx.x & 0x1f; \
|
||||||
int warp_id = threadIdx.x >> 5; \
|
int warp_id = threadIdx.x >> 5; \
|
||||||
\
|
\
|
||||||
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR); \
|
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR); \
|
||||||
if (lane_id == 0) { \
|
if (lane_id == 0) { \
|
||||||
for (int offset = 0; offset < LANES; ++offset) { \
|
for (int offset = 0; offset < LANES; ++offset) { \
|
||||||
shm[offset][warp_id] = *(VAL_PTR + offset); \
|
shm[offset][warp_id] = *(VAL_PTR + offset); \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
__syncthreads(); \
|
__syncthreads(); \
|
||||||
\
|
\
|
||||||
for (int offset = 0; offset < LANES; ++offset) { \
|
_Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) { \
|
||||||
*(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \
|
*(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \
|
||||||
? shm[offset][lane_id] \
|
? shm[offset][lane_id] \
|
||||||
: static_cast<T>(DEFAULT_VALUE); \
|
: static_cast<T>(DEFAULT_VALUE); \
|
||||||
} \
|
} \
|
||||||
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);
|
warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);
|
||||||
|
|
||||||
template <typename T, ReduceType rtype, int lanes>
|
template <typename T, ReduceType rtype, int lanes, int width = 32>
|
||||||
__forceinline__ __device__ void warp_reduce(T* pval) {
|
__forceinline__ __device__ void warp_reduce(T* pval) {
|
||||||
typename GetOpForReduceType<T, rtype>::Op op;
|
typename GetOpForReduceType<T, rtype>::Op op;
|
||||||
COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes);
|
COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, ReduceType rtype>
|
template <typename T, ReduceType rtype>
|
||||||
@@ -84,8 +81,7 @@ template <typename T, ReduceType rtype, int lanes>
|
|||||||
__forceinline__ __device__ void block_reduce(T* pval) {
|
__forceinline__ __device__ void block_reduce(T* pval) {
|
||||||
constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
|
constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
|
||||||
typename GetOpForReduceType<T, rtype>::Op op;
|
typename GetOpForReduceType<T, rtype>::Op op;
|
||||||
COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue,
|
COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype);
|
||||||
rtype);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef COLOSSAL_SHFL_FUNCTION
|
#undef COLOSSAL_SHFL_FUNCTION
|
||||||
|
@@ -6,10 +6,6 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace multihead_attn {
|
|
||||||
namespace fused_softmax {
|
|
||||||
namespace scaled_masked_softmax {
|
|
||||||
|
|
||||||
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
||||||
float scale_factor);
|
float scale_factor);
|
||||||
|
|
||||||
@@ -17,8 +13,8 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
|
|||||||
torch::Tensor const& softmax_results,
|
torch::Tensor const& softmax_results,
|
||||||
float scale_factor);
|
float scale_factor);
|
||||||
|
|
||||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
|
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||||
int attn_heads);
|
int attn_heads);
|
||||||
|
|
||||||
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
|
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
|
||||||
float scale_factor) {
|
float scale_factor) {
|
||||||
@@ -46,25 +42,13 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
|
|||||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
|
||||||
int attn_heads) {
|
|
||||||
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
|
|
||||||
attn_heads);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // end namespace scaled_masked_softmax
|
|
||||||
} // end namespace fused_softmax
|
|
||||||
} // end namespace multihead_attn
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
|
m.def("forward", &fwd,
|
||||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||||
|
|
||||||
m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
|
m.def("backward", &bwd,
|
||||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||||
|
|
||||||
m.def("get_batch_per_block",
|
m.def("get_batch_per_block", &get_batch_per_block,
|
||||||
&multihead_attn::fused_softmax::scaled_masked_softmax::
|
|
||||||
get_batch_per_block,
|
|
||||||
"Return Batch per block size.");
|
"Return Batch per block size.");
|
||||||
}
|
}
|
||||||
|
@@ -6,10 +6,6 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace multihead_attn {
|
|
||||||
namespace fused_softmax {
|
|
||||||
namespace scaled_upper_triang_masked_softmax {
|
|
||||||
|
|
||||||
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
|
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
|
||||||
|
|
||||||
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
|
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
|
||||||
@@ -40,15 +36,9 @@ torch::Tensor bwd(torch::Tensor const& output_grads,
|
|||||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace scaled_upper_triang_masked_softmax
|
|
||||||
} // end namespace fused_softmax
|
|
||||||
} // end namespace multihead_attn
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("forward",
|
m.def("forward", &fwd,
|
||||||
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
|
|
||||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||||
m.def("backward",
|
m.def("backward", &bwd,
|
||||||
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
|
|
||||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||||
}
|
}
|
||||||
|
@@ -11,42 +11,33 @@
|
|||||||
#include "block_reduce.h"
|
#include "block_reduce.h"
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
#include "funcs/cast_functor.h"
|
#include "funcs/cast_functor.h"
|
||||||
#include "funcs/op_functor.h"
|
#include "funcs/binary_functor.h"
|
||||||
|
|
||||||
using colossalAI::cuda::utils::block_reduce;
|
using colossalAI::cuda::utils::block_reduce;
|
||||||
using colossalAI::cuda::utils::ReduceType;
|
using colossalAI::cuda::utils::ReduceType;
|
||||||
using colossalAI::cuda::funcs::TypeConverter;
|
|
||||||
using colossalAI::cuda::funcs::CastFunctor;
|
using colossalAI::cuda::funcs::CastFunctor;
|
||||||
using colossalAI::cuda::funcs::BinaryOpFunctor;
|
using colossalAI::cuda::funcs::BinaryOpFunctor;
|
||||||
using colossalAI::cuda::funcs::BinaryOpType;
|
using colossalAI::cuda::funcs::BinaryOpType;
|
||||||
|
|
||||||
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
|
|
||||||
if (DATA_SIZE == 2) { \
|
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||||
switch (TYPE) { \
|
template <typename T>
|
||||||
case at::ScalarType::Half: { \
|
struct TypeConverter {
|
||||||
using scalar_t = at::Half; \
|
using Type = half2;
|
||||||
__VA_ARGS__; \
|
};
|
||||||
break; \
|
|
||||||
} \
|
#define TYPE_CONVERTER_SPECIALIZATION(FROM, TO) \
|
||||||
case at::ScalarType::BFloat16: { \
|
template <> \
|
||||||
using scalar_t = at::BFloat16; \
|
struct TypeConverter<FROM> { \
|
||||||
__VA_ARGS__; \
|
using Type = TO; \
|
||||||
break; \
|
};
|
||||||
} \
|
|
||||||
default: \
|
TYPE_CONVERTER_SPECIALIZATION(half2, at::Half)
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
TYPE_CONVERTER_SPECIALIZATION(at::Half, half2)
|
||||||
} \
|
TYPE_CONVERTER_SPECIALIZATION(__nv_bfloat162, at::BFloat16)
|
||||||
} else { \
|
TYPE_CONVERTER_SPECIALIZATION(at::BFloat16, __nv_bfloat162)
|
||||||
switch (TYPE) { \
|
|
||||||
case at::ScalarType::Float: { \
|
#undef TYPE_CONVERTER_SPECIALIZATION
|
||||||
using scalar_t = float; \
|
|
||||||
general_##__VA_ARGS__; \
|
|
||||||
break; \
|
|
||||||
} \
|
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
|
|
||||||
// optimized for half and bf16
|
// optimized for half and bf16
|
||||||
template<typename scalar_t, int unroll_factor>
|
template<typename scalar_t, int unroll_factor>
|
||||||
@@ -217,6 +208,36 @@ __global__ void general_fused_add_rms_layernorm_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
|
||||||
|
if (DATA_SIZE == 2) { \
|
||||||
|
switch (TYPE) { \
|
||||||
|
case at::ScalarType::Half: { \
|
||||||
|
using scalar_t = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: { \
|
||||||
|
using scalar_t = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
switch (TYPE) { \
|
||||||
|
case at::ScalarType::Float: { \
|
||||||
|
using scalar_t = float; \
|
||||||
|
general_##__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
|
||||||
void rms_layernorm(
|
void rms_layernorm(
|
||||||
torch::Tensor& out, // [..., hidden_size]
|
torch::Tensor& out, // [..., hidden_size]
|
||||||
torch::Tensor& input, // [..., hidden_size]
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
@@ -424,3 +445,5 @@ void fused_add_rms_layernorm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT
|
||||||
|
@@ -1,500 +0,0 @@
|
|||||||
/*This code from NVIDIA Megatron:
|
|
||||||
* with minor changes. */
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <c10/macros/Macros.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
|
|
||||||
#include <cfloat>
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include "utils/vector_copy_utils.h"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
int log2_ceil(int value) {
|
|
||||||
int log2_value = 0;
|
|
||||||
while ((1 << log2_value) < value) ++log2_value;
|
|
||||||
return log2_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Add {
|
|
||||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Max {
|
|
||||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
|
||||||
return a < b ? b : a;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ __forceinline__ T
|
|
||||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
|
||||||
unsigned int mask = 0xffffffff) {
|
|
||||||
#if CUDA_VERSION >= 9000
|
|
||||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
|
||||||
#else
|
|
||||||
return __shfl_xor(value, laneMask, width);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
|
||||||
template <typename> class ReduceOp>
|
|
||||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
|
||||||
ReduceOp<acc_t> r;
|
|
||||||
#pragma unroll
|
|
||||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
|
||||||
sum[i] = r(sum[i], b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Extended softmax (from native aten pytorch) with following additional
|
|
||||||
* features 1) input scaling 2) Explicit masking
|
|
||||||
*/
|
|
||||||
template <typename input_t, typename output_t, typename acc_t,
|
|
||||||
int log2_elements>
|
|
||||||
__global__ void scaled_masked_softmax_warp_forward(
|
|
||||||
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
|
|
||||||
int micro_batch_size, int element_count, int pad_batches) {
|
|
||||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
|
||||||
// warp_size of method warp_softmax_forward_kernel.
|
|
||||||
constexpr int next_power_of_two = 1 << log2_elements;
|
|
||||||
constexpr int WARP_SIZE =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
|
||||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
|
||||||
|
|
||||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
|
||||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
|
||||||
int first_batch =
|
|
||||||
(blockDim.y *
|
|
||||||
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
|
|
||||||
threadIdx.y) *
|
|
||||||
WARP_BATCH;
|
|
||||||
int pad_first_batch = 0;
|
|
||||||
if (pad_batches != 1) { // bert style
|
|
||||||
pad_first_batch =
|
|
||||||
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
|
|
||||||
WARP_BATCH;
|
|
||||||
} else { // gpt2 style
|
|
||||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
|
||||||
}
|
|
||||||
|
|
||||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
|
||||||
// many batches have to computed within this WARP.
|
|
||||||
int local_batches = micro_batch_size - first_batch;
|
|
||||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
|
||||||
|
|
||||||
// there might be multiple batches per warp. compute the index within the
|
|
||||||
// batch
|
|
||||||
int local_idx = threadIdx.x;
|
|
||||||
|
|
||||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
|
|
||||||
// load data from global memory
|
|
||||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
|
||||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
|
||||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
|
|
||||||
if (element_index < batch_element_count) {
|
|
||||||
int itr_idx = i * element_count + it * WARP_SIZE;
|
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
|
||||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
if (temp_mask[element] != 1) {
|
|
||||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
|
||||||
} else {
|
|
||||||
elements[i][it + element] = -10000.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute max_value
|
|
||||||
acc_t max_value[WARP_BATCH];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
max_value[i] = elements[i][0];
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
|
||||||
max_value[i] =
|
|
||||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
|
||||||
|
|
||||||
acc_t sum[WARP_BATCH]{0.0f};
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
|
||||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
|
||||||
sum[i] += elements[i][it];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
|
||||||
|
|
||||||
// store result
|
|
||||||
output_t out[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
if (i >= local_batches) break;
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
if (element_index < element_count) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
out[element] = elements[i][it + element] / sum[i];
|
|
||||||
}
|
|
||||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
dst + i * element_count + it * WARP_SIZE, out);
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t,
|
|
||||||
int log2_elements>
|
|
||||||
__global__ void scaled_masked_softmax_warp_backward(
|
|
||||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
|
||||||
int micro_batch_size, int element_count) {
|
|
||||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
|
||||||
// warp_size of method warp_softmax_backward_kernel.
|
|
||||||
constexpr int next_power_of_two = 1 << log2_elements;
|
|
||||||
constexpr int WARP_SIZE =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
|
||||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
|
||||||
|
|
||||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
|
||||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
|
||||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
|
||||||
|
|
||||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
|
||||||
// many batches have to computed within this WARP.
|
|
||||||
int local_batches = micro_batch_size - first_batch;
|
|
||||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
|
||||||
|
|
||||||
// there might be multiple batches per warp. compute the index within the
|
|
||||||
// batch
|
|
||||||
int local_idx = threadIdx.x;
|
|
||||||
|
|
||||||
// the first element to process by the current thread
|
|
||||||
int thread_offset =
|
|
||||||
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
grad += thread_offset;
|
|
||||||
output += thread_offset;
|
|
||||||
gradInput += thread_offset;
|
|
||||||
|
|
||||||
// load data from global memory
|
|
||||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
|
||||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
|
||||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
|
||||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
if (element_index < batch_element_count) {
|
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
temp_grad, grad + i * element_count + it * WARP_SIZE);
|
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
temp_output, output + i * element_count + it * WARP_SIZE);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
grad_reg[i][it + element] =
|
|
||||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
acc_t sum[WARP_BATCH];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
sum[i] = grad_reg[i][0];
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
|
||||||
sum[i] += grad_reg[i][it];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
|
||||||
|
|
||||||
// store result
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
if (i >= local_batches) break;
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
if (element_index < element_count) {
|
|
||||||
// compute gradients
|
|
||||||
output_t out[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
out[element] =
|
|
||||||
(output_t)(scale * (grad_reg[i][it + element] -
|
|
||||||
output_reg[i][it + element] * sum[i]));
|
|
||||||
}
|
|
||||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
gradInput + i * element_count + it * WARP_SIZE, out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // end of anonymous namespace
|
|
||||||
|
|
||||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
|
||||||
int attn_heads) {
|
|
||||||
int log2_elements = log2_ceil(key_seq_len);
|
|
||||||
const int next_power_of_two = 1 << log2_elements;
|
|
||||||
|
|
||||||
int warp_size =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
|
|
||||||
constexpr int threads_per_block = 128;
|
|
||||||
int warps_per_block = (threads_per_block / warp_size);
|
|
||||||
int batches_per_block = warps_per_block * batches_per_warp;
|
|
||||||
|
|
||||||
return batches_per_block;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t>
|
|
||||||
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
|
|
||||||
const uint8_t *mask,
|
|
||||||
const input_t scale,
|
|
||||||
int query_seq_len, int key_seq_len,
|
|
||||||
int batches, int attn_heads,
|
|
||||||
int pad_batches) {
|
|
||||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
|
||||||
if (key_seq_len == 0) {
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
int log2_elements = log2_ceil(key_seq_len);
|
|
||||||
const int next_power_of_two = 1 << log2_elements;
|
|
||||||
int batch_count = batches * attn_heads * query_seq_len;
|
|
||||||
|
|
||||||
// This value must match the WARP_SIZE constexpr value computed inside
|
|
||||||
// softmax_warp_forward.
|
|
||||||
int warp_size =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
|
|
||||||
// This value must match the WARP_BATCH constexpr value computed inside
|
|
||||||
// softmax_warp_forward.
|
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
|
|
||||||
// use 128 threads per block to maximimize gpu utilization
|
|
||||||
constexpr int threads_per_block = 128;
|
|
||||||
|
|
||||||
int warps_per_block = (threads_per_block / warp_size);
|
|
||||||
int batches_per_block = warps_per_block * batches_per_warp;
|
|
||||||
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
|
|
||||||
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
|
|
||||||
dim3 threads(warp_size, warps_per_block, 1);
|
|
||||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
|
||||||
switch (log2_elements) {
|
|
||||||
case 0: // 1
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 1: // 2
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 2: // 4
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 3: // 8
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 4: // 16
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 5: // 32
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 6: // 64
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 7: // 128
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 8: // 256
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 9: // 512
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 10: // 1024
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
case 11: // 2048
|
|
||||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t>
|
|
||||||
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
|
|
||||||
input_t *grad,
|
|
||||||
const input_t *output,
|
|
||||||
const acc_t scale,
|
|
||||||
int query_seq_len, int key_seq_len,
|
|
||||||
int batches, int attn_heads) {
|
|
||||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
|
||||||
if (key_seq_len == 0) {
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
int log2_elements = log2_ceil(key_seq_len);
|
|
||||||
const int next_power_of_two = 1 << log2_elements;
|
|
||||||
int batch_count = batches * attn_heads * query_seq_len;
|
|
||||||
|
|
||||||
// This value must match the WARP_SIZE constexpr value computed inside
|
|
||||||
// softmax_warp_backward.
|
|
||||||
int warp_size =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
|
|
||||||
// This value must match the WARP_BATCH constexpr value computed inside
|
|
||||||
// softmax_warp_backward.
|
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
|
|
||||||
// use 128 threads per block to maximimize gpu utilization
|
|
||||||
constexpr int threads_per_block = 128;
|
|
||||||
|
|
||||||
int warps_per_block = (threads_per_block / warp_size);
|
|
||||||
int batches_per_block = warps_per_block * batches_per_warp;
|
|
||||||
int blocks = batch_count / batches_per_block;
|
|
||||||
dim3 threads(warp_size, warps_per_block, 1);
|
|
||||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
|
||||||
switch (log2_elements) {
|
|
||||||
case 0: // 1
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 1: // 2
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 2: // 4
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 3: // 8
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 4: // 16
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 5: // 32
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 6: // 64
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 7: // 128
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 8: // 256
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 9: // 512
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 10: // 1024
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
case 11: // 2048
|
|
||||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@@ -9,16 +9,462 @@
|
|||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "scaled_masked_softmax.h"
|
#include <assert.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
|
#include "utils/vec_copy.h"
|
||||||
|
#include "include/block_reduce.h"
|
||||||
|
#include "funcs/unary_functor.h"
|
||||||
|
|
||||||
namespace multihead_attn {
|
using colossalAI::cuda::funcs::UnaryOpFunctor;
|
||||||
namespace fused_softmax {
|
using colossalAI::cuda::funcs::UnaryOpType;
|
||||||
namespace scaled_masked_softmax {
|
using colossalAI::cuda::utils::warp_reduce;
|
||||||
|
using colossalAI::cuda::utils::ReduceType;
|
||||||
|
|
||||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
|
|
||||||
int attn_heads) {
|
/*
|
||||||
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
|
* Extended softmax (from native aten pytorch) with following additional
|
||||||
|
* features 1) input scaling 2) Explicit masking
|
||||||
|
*/
|
||||||
|
template <typename input_t, typename output_t, typename acc_t,
|
||||||
|
int log2_elements>
|
||||||
|
__global__ void scaled_masked_softmax_warp_forward(
|
||||||
|
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
|
||||||
|
int micro_batch_size, int element_count, int pad_batches) {
|
||||||
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
|
// warp_size of method warp_softmax_forward_kernel.
|
||||||
|
constexpr int next_power_of_two = 1 << log2_elements;
|
||||||
|
constexpr int WARP_SIZE =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||||
|
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
|
||||||
|
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||||
|
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||||
|
int first_batch =
|
||||||
|
(blockDim.y *
|
||||||
|
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
|
||||||
|
threadIdx.y) *
|
||||||
|
WARP_BATCH;
|
||||||
|
int pad_first_batch = 0;
|
||||||
|
if (pad_batches != 1) { // bert style
|
||||||
|
pad_first_batch =
|
||||||
|
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
|
||||||
|
WARP_BATCH;
|
||||||
|
} else { // gpt2 style
|
||||||
|
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||||
|
}
|
||||||
|
|
||||||
|
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||||
|
// many batches have to computed within this WARP.
|
||||||
|
int local_batches = micro_batch_size - first_batch;
|
||||||
|
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||||
|
|
||||||
|
// there might be multiple batches per warp. compute the index within the
|
||||||
|
// batch
|
||||||
|
int local_idx = threadIdx.x;
|
||||||
|
|
||||||
|
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
|
||||||
|
// load data from global memory
|
||||||
|
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||||
|
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||||
|
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
|
||||||
|
if (element_index < batch_element_count) {
|
||||||
|
int itr_idx = i * element_count + it * WARP_SIZE;
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||||
|
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
if (temp_mask[element] != 1) {
|
||||||
|
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||||
|
} else {
|
||||||
|
elements[i][it + element] = -10000.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute max_value
|
||||||
|
acc_t max_value[WARP_BATCH];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
max_value[i] = elements[i][0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||||
|
max_value[i] =
|
||||||
|
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warp_reduce<acc_t,ReduceType::kMax,WARP_BATCH,WARP_SIZE>(max_value);
|
||||||
|
|
||||||
|
acc_t sum[WARP_BATCH]{0.0f};
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||||
|
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||||
|
sum[i] += elements[i][it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warp_reduce<acc_t,ReduceType::kSum,WARP_BATCH,WARP_SIZE>(sum);
|
||||||
|
|
||||||
|
// store result
|
||||||
|
output_t out[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
if (i >= local_batches) break;
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
if (element_index < element_count) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
out[element] = elements[i][it + element] / sum[i];
|
||||||
|
}
|
||||||
|
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
dst + i * element_count + it * WARP_SIZE, out);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t,
|
||||||
|
int log2_elements>
|
||||||
|
__global__ void scaled_masked_softmax_warp_backward(
|
||||||
|
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||||
|
int micro_batch_size, int element_count) {
|
||||||
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
|
// warp_size of method warp_softmax_backward_kernel.
|
||||||
|
constexpr int next_power_of_two = 1 << log2_elements;
|
||||||
|
constexpr int WARP_SIZE =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||||
|
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
|
||||||
|
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||||
|
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||||
|
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||||
|
|
||||||
|
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||||
|
// many batches have to computed within this WARP.
|
||||||
|
int local_batches = micro_batch_size - first_batch;
|
||||||
|
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||||
|
|
||||||
|
// there might be multiple batches per warp. compute the index within the
|
||||||
|
// batch
|
||||||
|
int local_idx = threadIdx.x;
|
||||||
|
|
||||||
|
// the first element to process by the current thread
|
||||||
|
int thread_offset =
|
||||||
|
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
grad += thread_offset;
|
||||||
|
output += thread_offset;
|
||||||
|
gradInput += thread_offset;
|
||||||
|
|
||||||
|
// load data from global memory
|
||||||
|
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||||
|
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||||
|
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||||
|
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
if (element_index < batch_element_count) {
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
temp_output, output + i * element_count + it * WARP_SIZE);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
grad_reg[i][it + element] =
|
||||||
|
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acc_t sum[WARP_BATCH];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
sum[i] = grad_reg[i][0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||||
|
sum[i] += grad_reg[i][it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warp_reduce<acc_t,ReduceType::kSum,WARP_BATCH,WARP_SIZE>(sum);
|
||||||
|
|
||||||
|
// store result
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
if (i >= local_batches) break;
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
if (element_index < element_count) {
|
||||||
|
// compute gradients
|
||||||
|
output_t out[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
out[element] =
|
||||||
|
(output_t)(scale * (grad_reg[i][it + element] -
|
||||||
|
output_reg[i][it + element] * sum[i]));
|
||||||
|
}
|
||||||
|
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
gradInput + i * element_count + it * WARP_SIZE, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||||
|
int attn_heads) {
|
||||||
|
int log2_elements = UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(key_seq_len);
|
||||||
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
|
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
int warps_per_block = (threads_per_block / warp_size);
|
||||||
|
int batches_per_block = warps_per_block * batches_per_warp;
|
||||||
|
|
||||||
|
return batches_per_block;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
|
||||||
|
const uint8_t *mask,
|
||||||
|
const input_t scale,
|
||||||
|
int query_seq_len, int key_seq_len,
|
||||||
|
int batches, int attn_heads,
|
||||||
|
int pad_batches) {
|
||||||
|
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||||
|
if (key_seq_len == 0) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
int log2_elements = UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(key_seq_len);
|
||||||
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
|
int batch_count = batches * attn_heads * query_seq_len;
|
||||||
|
|
||||||
|
// This value must match the WARP_SIZE constexpr value computed inside
|
||||||
|
// softmax_warp_forward.
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
|
||||||
|
// This value must match the WARP_BATCH constexpr value computed inside
|
||||||
|
// softmax_warp_forward.
|
||||||
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
|
// use 128 threads per block to maximimize gpu utilization
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
|
||||||
|
int warps_per_block = (threads_per_block / warp_size);
|
||||||
|
int batches_per_block = warps_per_block * batches_per_warp;
|
||||||
|
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
|
||||||
|
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
|
||||||
|
dim3 threads(warp_size, warps_per_block, 1);
|
||||||
|
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||||
|
switch (log2_elements) {
|
||||||
|
case 0: // 1
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 1: // 2
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 2: // 4
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 3: // 8
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 4: // 16
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 5: // 32
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 6: // 64
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 7: // 128
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 8: // 256
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 9: // 512
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 10: // 1024
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
case 11: // 2048
|
||||||
|
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
|
||||||
|
input_t *grad,
|
||||||
|
const input_t *output,
|
||||||
|
const acc_t scale,
|
||||||
|
int query_seq_len, int key_seq_len,
|
||||||
|
int batches, int attn_heads) {
|
||||||
|
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||||
|
if (key_seq_len == 0) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
int log2_elements = UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(key_seq_len);
|
||||||
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
|
int batch_count = batches * attn_heads * query_seq_len;
|
||||||
|
|
||||||
|
// This value must match the WARP_SIZE constexpr value computed inside
|
||||||
|
// softmax_warp_backward.
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
|
||||||
|
// This value must match the WARP_BATCH constexpr value computed inside
|
||||||
|
// softmax_warp_backward.
|
||||||
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
|
// use 128 threads per block to maximimize gpu utilization
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
|
||||||
|
int warps_per_block = (threads_per_block / warp_size);
|
||||||
|
int batches_per_block = warps_per_block * batches_per_warp;
|
||||||
|
int blocks = batch_count / batches_per_block;
|
||||||
|
dim3 threads(warp_size, warps_per_block, 1);
|
||||||
|
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||||
|
switch (log2_elements) {
|
||||||
|
case 0: // 1
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 1: // 2
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 2: // 4
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 3: // 8
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 4: // 16
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 5: // 32
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 6: // 64
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 7: // 128
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 8: // 256
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 9: // 512
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 10: // 1024
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
case 11: // 2048
|
||||||
|
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
||||||
@@ -84,6 +530,3 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
|
|||||||
// backward pass is completely in-place
|
// backward pass is completely in-place
|
||||||
return output_grads;
|
return output_grads;
|
||||||
}
|
}
|
||||||
} // namespace scaled_masked_softmax
|
|
||||||
} // namespace fused_softmax
|
|
||||||
} // namespace multihead_attn
|
|
||||||
|
@@ -1,538 +0,0 @@
|
|||||||
/*This code from NVIDIA Megatron:
|
|
||||||
* with minor changes. */
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
#include <c10/macros/Macros.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <cfloat>
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include "utils/vector_copy_utils.h"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
int log2_ceil(int value) {
|
|
||||||
int log2_value = 0;
|
|
||||||
while ((1 << log2_value) < value) ++log2_value;
|
|
||||||
return log2_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Add {
|
|
||||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Max {
|
|
||||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
|
||||||
return a < b ? b : a;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ __forceinline__ T
|
|
||||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
|
||||||
unsigned int mask = 0xffffffff) {
|
|
||||||
#if CUDA_VERSION >= 9000
|
|
||||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
|
||||||
#else
|
|
||||||
return __shfl_xor(value, laneMask, width);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
|
||||||
template <typename> class ReduceOp>
|
|
||||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
|
||||||
ReduceOp<acc_t> r;
|
|
||||||
#pragma unroll
|
|
||||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
|
||||||
sum[i] = r(sum[i], b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Extended softmax (from native aten pytorch) with following additional
|
|
||||||
* features 1) input scaling 2) Implicit time (diagonal masking)
|
|
||||||
*/
|
|
||||||
template <typename input_t, typename output_t, typename acc_t,
|
|
||||||
int log2_elements>
|
|
||||||
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
|
|
||||||
output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
|
|
||||||
int stride, int element_count) {
|
|
||||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
|
||||||
// warp_size of method warp_softmax_forward_kernel.
|
|
||||||
constexpr int next_power_of_two = 1 << log2_elements;
|
|
||||||
constexpr int WARP_SIZE =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
|
||||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
|
||||||
|
|
||||||
int first_batch =
|
|
||||||
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
|
||||||
blockIdx.x;
|
|
||||||
int local_seq = blockIdx.x + 1;
|
|
||||||
int warp_iteration_limit =
|
|
||||||
(local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
|
|
||||||
|
|
||||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
|
||||||
// many batches have to computed within this WARP.
|
|
||||||
int local_batches = micro_batch_size - first_batch;
|
|
||||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
|
||||||
|
|
||||||
// there might be multiple batches per warp. compute the index within the
|
|
||||||
// batch
|
|
||||||
int local_idx = threadIdx.x;
|
|
||||||
|
|
||||||
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
|
|
||||||
// load data from global memory
|
|
||||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
|
||||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
|
|
||||||
if (element_index < batch_element_count) {
|
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
temp_data, src + i * element_count * stride + it * WARP_SIZE);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
if ((element_index + element) < batch_element_count) {
|
|
||||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
|
||||||
} else {
|
|
||||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute max_value
|
|
||||||
acc_t max_value[WARP_BATCH];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
max_value[i] = elements[i][0];
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
|
||||||
max_value[i] =
|
|
||||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
|
||||||
|
|
||||||
acc_t sum[WARP_BATCH]{0.0f};
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
|
||||||
if (it < warp_iteration_limit) {
|
|
||||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
|
||||||
sum[i] += elements[i][it];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
|
||||||
|
|
||||||
// store result
|
|
||||||
output_t out[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
if (i >= local_batches) break;
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
|
|
||||||
if (element_index < local_seq) {
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
if (element_index + element < local_seq) {
|
|
||||||
out[element] = elements[i][it + element] / sum[i];
|
|
||||||
} else {
|
|
||||||
out[element] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
dst + i * element_count * stride + it * WARP_SIZE, out);
|
|
||||||
} else if (element_index < element_count) {
|
|
||||||
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
dst + i * element_count * stride + it * WARP_SIZE);
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t,
|
|
||||||
int log2_elements>
|
|
||||||
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
|
|
||||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
|
||||||
int micro_batch_size, int stride, int element_count) {
|
|
||||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
|
||||||
// warp_size of method warp_softmax_backward_kernel.
|
|
||||||
constexpr int next_power_of_two = 1 << log2_elements;
|
|
||||||
constexpr int WARP_SIZE =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
|
||||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
|
||||||
|
|
||||||
int first_batch =
|
|
||||||
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
|
||||||
blockIdx.x;
|
|
||||||
int local_seq = blockIdx.x + 1;
|
|
||||||
|
|
||||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
|
||||||
// many batches have to computed within this WARP.
|
|
||||||
int local_batches = micro_batch_size - first_batch;
|
|
||||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
|
||||||
|
|
||||||
// there might be multiple batches per warp. compute the index within the
|
|
||||||
// batch
|
|
||||||
int local_idx = threadIdx.x;
|
|
||||||
|
|
||||||
// the first element to process by the current thread
|
|
||||||
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
|
||||||
grad += thread_offset;
|
|
||||||
output += thread_offset;
|
|
||||||
gradInput += thread_offset;
|
|
||||||
|
|
||||||
// load data from global memory
|
|
||||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
|
||||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
|
||||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
|
||||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
if (element_index < batch_element_count) {
|
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
|
|
||||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
temp_output, output + i * element_count * stride + it * WARP_SIZE);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
if (element_index + element < batch_element_count) {
|
|
||||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
if (element_index + element < batch_element_count) {
|
|
||||||
grad_reg[i][it + element] =
|
|
||||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
acc_t sum[WARP_BATCH];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
sum[i] = grad_reg[i][0];
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
|
||||||
sum[i] += grad_reg[i][it];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
|
||||||
|
|
||||||
// store result
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
|
||||||
if (i >= local_batches) break;
|
|
||||||
#pragma unroll
|
|
||||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
|
||||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
|
||||||
if (element_index < element_count) {
|
|
||||||
// compute gradients
|
|
||||||
output_t out[ELEMENTS_PER_LDG_STG];
|
|
||||||
#pragma unroll
|
|
||||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
|
||||||
out[element] =
|
|
||||||
(output_t)(scale * (grad_reg[i][it + element] -
|
|
||||||
output_reg[i][it + element] * sum[i]));
|
|
||||||
}
|
|
||||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
|
||||||
gradInput + i * element_count * stride + it * WARP_SIZE, out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // end of anonymous namespace
|
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t>
|
|
||||||
void dispatch_scaled_upper_triang_masked_softmax_forward(
|
|
||||||
output_t *dst, const input_t *src, const input_t scale,
|
|
||||||
int softmax_elements, int softmax_elements_stride, int attn_batches) {
|
|
||||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
|
||||||
if (softmax_elements == 0) {
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
int log2_elements = log2_ceil(softmax_elements);
|
|
||||||
const int next_power_of_two = 1 << log2_elements;
|
|
||||||
int seq_len = softmax_elements;
|
|
||||||
int batch_count = attn_batches * seq_len;
|
|
||||||
|
|
||||||
// This value must match the WARP_SIZE constexpr value computed inside
|
|
||||||
// softmax_warp_forward.
|
|
||||||
int warp_size =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
|
|
||||||
// This value must match the WARP_BATCH constexpr value computed inside
|
|
||||||
// softmax_warp_forward.
|
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
|
|
||||||
// use 128 threads per block to maximimize gpu utilization
|
|
||||||
constexpr int threads_per_block = 128;
|
|
||||||
|
|
||||||
int warps_per_block = (threads_per_block / warp_size);
|
|
||||||
int batches_per_block = warps_per_block * batches_per_warp;
|
|
||||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
|
||||||
|
|
||||||
int blocks_per_seq = attn_batches / batches_per_block;
|
|
||||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
|
||||||
dim3 threads(warp_size, warps_per_block, 1);
|
|
||||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
|
||||||
switch (log2_elements) {
|
|
||||||
case 0: // 1
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 0>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 1: // 2
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 1>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 2: // 4
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 2>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 3: // 8
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 3>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 4: // 16
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 4>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 5: // 32
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 5>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 6: // 64
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 6>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 7: // 128
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 7>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 8: // 256
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 8>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 9: // 512
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 9>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 10: // 1024
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 10>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
case 11: // 2048
|
|
||||||
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
|
||||||
acc_t, 11>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
dst, src, scale, batch_count, softmax_elements_stride,
|
|
||||||
softmax_elements);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename input_t, typename output_t, typename acc_t>
|
|
||||||
void dispatch_scaled_upper_triang_masked_softmax_backward(
|
|
||||||
output_t *grad_input, input_t *grad, const input_t *output,
|
|
||||||
const acc_t scale, int softmax_elements, int softmax_elements_stride,
|
|
||||||
int attn_batches) {
|
|
||||||
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
|
||||||
if (softmax_elements == 0) {
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
int log2_elements = log2_ceil(softmax_elements);
|
|
||||||
const int next_power_of_two = 1 << log2_elements;
|
|
||||||
int seq_len = softmax_elements;
|
|
||||||
int batch_count = attn_batches * seq_len;
|
|
||||||
|
|
||||||
// This value must match the WARP_SIZE constexpr value computed inside
|
|
||||||
// softmax_warp_backward.
|
|
||||||
int warp_size =
|
|
||||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
|
||||||
|
|
||||||
// This value must match the WARP_BATCH constexpr value computed inside
|
|
||||||
// softmax_warp_backward.
|
|
||||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
|
||||||
|
|
||||||
// use 128 threads per block to maximimize gpu utilization
|
|
||||||
constexpr int threads_per_block = 128;
|
|
||||||
|
|
||||||
int warps_per_block = (threads_per_block / warp_size);
|
|
||||||
int batches_per_block = warps_per_block * batches_per_warp;
|
|
||||||
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
|
||||||
|
|
||||||
int blocks_per_seq = attn_batches / batches_per_block;
|
|
||||||
dim3 blocks(seq_len, blocks_per_seq, 1);
|
|
||||||
dim3 threads(warp_size, warps_per_block, 1);
|
|
||||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
|
||||||
switch (log2_elements) {
|
|
||||||
case 0: // 1
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 0>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 1: // 2
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 1>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 2: // 4
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 2>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 3: // 8
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 3>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 4: // 16
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 4>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 5: // 32
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 5>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 6: // 64
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 6>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 7: // 128
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 7>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 8: // 256
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 8>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 9: // 512
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 9>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 10: // 1024
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 10>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
case 11: // 2048
|
|
||||||
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
|
||||||
acc_t, 11>
|
|
||||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
|
||||||
grad_input, grad, output, scale, batch_count,
|
|
||||||
softmax_elements_stride, softmax_elements);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@@ -8,13 +8,502 @@
|
|||||||
#include <cuda_profiler_api.h>
|
#include <cuda_profiler_api.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
#include <assert.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "scaled_upper_triang_masked_softmax.h"
|
|
||||||
#include "../common/micros.h"
|
#include "../common/micros.h"
|
||||||
|
#include "utils/vec_copy.h"
|
||||||
|
#include "include/block_reduce.h"
|
||||||
|
#include "funcs/unary_functor.h"
|
||||||
|
|
||||||
|
using colossalAI::cuda::funcs::UnaryOpFunctor;
|
||||||
|
using colossalAI::cuda::funcs::UnaryOpType;
|
||||||
|
using colossalAI::cuda::utils::warp_reduce;
|
||||||
|
using colossalAI::cuda::utils::ReduceType;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Extended softmax (from native aten pytorch) with following additional
|
||||||
|
* features 1) input scaling 2) Implicit time (diagonal masking)
|
||||||
|
*/
|
||||||
|
template <typename input_t, typename output_t, typename acc_t,
|
||||||
|
int log2_elements>
|
||||||
|
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
|
||||||
|
output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
|
||||||
|
int stride, int element_count) {
|
||||||
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
|
// warp_size of method warp_softmax_forward_kernel.
|
||||||
|
constexpr int next_power_of_two = 1 << log2_elements;
|
||||||
|
constexpr int WARP_SIZE =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||||
|
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
|
||||||
|
int first_batch =
|
||||||
|
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
||||||
|
blockIdx.x;
|
||||||
|
int local_seq = blockIdx.x + 1;
|
||||||
|
int warp_iteration_limit =
|
||||||
|
(local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
|
||||||
|
|
||||||
|
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||||
|
// many batches have to computed within this WARP.
|
||||||
|
int local_batches = micro_batch_size - first_batch;
|
||||||
|
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||||
|
|
||||||
|
// there might be multiple batches per warp. compute the index within the
|
||||||
|
// batch
|
||||||
|
int local_idx = threadIdx.x;
|
||||||
|
|
||||||
|
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
|
||||||
|
// load data from global memory
|
||||||
|
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||||
|
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
|
||||||
|
if (element_index < batch_element_count) {
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
temp_data, src + i * element_count * stride + it * WARP_SIZE);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
if ((element_index + element) < batch_element_count) {
|
||||||
|
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||||
|
} else {
|
||||||
|
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute max_value
|
||||||
|
acc_t max_value[WARP_BATCH];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
max_value[i] = elements[i][0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||||
|
max_value[i] =
|
||||||
|
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warp_reduce<acc_t,ReduceType::kMax,WARP_BATCH,WARP_SIZE>(max_value);
|
||||||
|
|
||||||
|
acc_t sum[WARP_BATCH]{0.0f};
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||||
|
if (it < warp_iteration_limit) {
|
||||||
|
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||||
|
sum[i] += elements[i][it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warp_reduce<acc_t,ReduceType::kSum,WARP_BATCH,WARP_SIZE>(sum);
|
||||||
|
|
||||||
|
|
||||||
|
// store result
|
||||||
|
output_t out[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
if (i >= local_batches) break;
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
|
||||||
|
if (element_index < local_seq) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
if (element_index + element < local_seq) {
|
||||||
|
out[element] = elements[i][it + element] / sum[i];
|
||||||
|
} else {
|
||||||
|
out[element] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
dst + i * element_count * stride + it * WARP_SIZE, out);
|
||||||
|
} else if (element_index < element_count) {
|
||||||
|
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
dst + i * element_count * stride + it * WARP_SIZE);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t,
|
||||||
|
int log2_elements>
|
||||||
|
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
|
||||||
|
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||||
|
int micro_batch_size, int stride, int element_count) {
|
||||||
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
|
// warp_size of method warp_softmax_backward_kernel.
|
||||||
|
constexpr int next_power_of_two = 1 << log2_elements;
|
||||||
|
constexpr int WARP_SIZE =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||||
|
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
|
||||||
|
int first_batch =
|
||||||
|
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
|
||||||
|
blockIdx.x;
|
||||||
|
int local_seq = blockIdx.x + 1;
|
||||||
|
|
||||||
|
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
|
||||||
|
// many batches have to computed within this WARP.
|
||||||
|
int local_batches = micro_batch_size - first_batch;
|
||||||
|
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||||
|
|
||||||
|
// there might be multiple batches per warp. compute the index within the
|
||||||
|
// batch
|
||||||
|
int local_idx = threadIdx.x;
|
||||||
|
|
||||||
|
// the first element to process by the current thread
|
||||||
|
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
|
||||||
|
grad += thread_offset;
|
||||||
|
output += thread_offset;
|
||||||
|
gradInput += thread_offset;
|
||||||
|
|
||||||
|
// load data from global memory
|
||||||
|
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||||
|
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||||
|
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||||
|
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
if (element_index < batch_element_count) {
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
|
||||||
|
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
temp_output, output + i * element_count * stride + it * WARP_SIZE);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
if (element_index + element < batch_element_count) {
|
||||||
|
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
if (element_index + element < batch_element_count) {
|
||||||
|
grad_reg[i][it + element] =
|
||||||
|
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acc_t sum[WARP_BATCH];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
sum[i] = grad_reg[i][0];
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||||
|
sum[i] += grad_reg[i][it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warp_reduce<acc_t,ReduceType::kSum,WARP_BATCH,WARP_SIZE>(sum);
|
||||||
|
|
||||||
|
// store result
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||||
|
if (i >= local_batches) break;
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||||
|
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||||
|
if (element_index < element_count) {
|
||||||
|
// compute gradients
|
||||||
|
output_t out[ELEMENTS_PER_LDG_STG];
|
||||||
|
#pragma unroll
|
||||||
|
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||||
|
out[element] =
|
||||||
|
(output_t)(scale * (grad_reg[i][it + element] -
|
||||||
|
output_reg[i][it + element] * sum[i]));
|
||||||
|
}
|
||||||
|
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||||
|
gradInput + i * element_count * stride + it * WARP_SIZE, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
void dispatch_scaled_upper_triang_masked_softmax_forward(
|
||||||
|
output_t *dst, const input_t *src, const input_t scale,
|
||||||
|
int softmax_elements, int softmax_elements_stride, int attn_batches) {
|
||||||
|
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
||||||
|
if (softmax_elements == 0) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
int log2_elements = UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(softmax_elements);
|
||||||
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
|
int seq_len = softmax_elements;
|
||||||
|
int batch_count = attn_batches * seq_len;
|
||||||
|
|
||||||
|
// This value must match the WARP_SIZE constexpr value computed inside
|
||||||
|
// softmax_warp_forward.
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
|
||||||
|
// This value must match the WARP_BATCH constexpr value computed inside
|
||||||
|
// softmax_warp_forward.
|
||||||
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
|
// use 128 threads per block to maximimize gpu utilization
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
|
||||||
|
int warps_per_block = (threads_per_block / warp_size);
|
||||||
|
int batches_per_block = warps_per_block * batches_per_warp;
|
||||||
|
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||||
|
|
||||||
|
int blocks_per_seq = attn_batches / batches_per_block;
|
||||||
|
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||||
|
dim3 threads(warp_size, warps_per_block, 1);
|
||||||
|
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||||
|
switch (log2_elements) {
|
||||||
|
case 0: // 1
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 0>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 1: // 2
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 1>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 2: // 4
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 2>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 3: // 8
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 3>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 4: // 16
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 4>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 5: // 32
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 5>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 6: // 64
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 6>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 7: // 128
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 7>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 8: // 256
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 8>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 9: // 512
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 9>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 10: // 1024
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 10>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
case 11: // 2048
|
||||||
|
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
|
||||||
|
acc_t, 11>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
dst, src, scale, batch_count, softmax_elements_stride,
|
||||||
|
softmax_elements);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
void dispatch_scaled_upper_triang_masked_softmax_backward(
|
||||||
|
output_t *grad_input, input_t *grad, const input_t *output,
|
||||||
|
const acc_t scale, int softmax_elements, int softmax_elements_stride,
|
||||||
|
int attn_batches) {
|
||||||
|
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
|
||||||
|
if (softmax_elements == 0) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
int log2_elements = UnaryOpFunctor<int, int, UnaryOpType::kLog2Ceil>()(softmax_elements);
|
||||||
|
const int next_power_of_two = 1 << log2_elements;
|
||||||
|
int seq_len = softmax_elements;
|
||||||
|
int batch_count = attn_batches * seq_len;
|
||||||
|
|
||||||
|
// This value must match the WARP_SIZE constexpr value computed inside
|
||||||
|
// softmax_warp_backward.
|
||||||
|
int warp_size =
|
||||||
|
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||||
|
|
||||||
|
// This value must match the WARP_BATCH constexpr value computed inside
|
||||||
|
// softmax_warp_backward.
|
||||||
|
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||||
|
|
||||||
|
// use 128 threads per block to maximimize gpu utilization
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
|
||||||
|
int warps_per_block = (threads_per_block / warp_size);
|
||||||
|
int batches_per_block = warps_per_block * batches_per_warp;
|
||||||
|
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
|
||||||
|
|
||||||
|
int blocks_per_seq = attn_batches / batches_per_block;
|
||||||
|
dim3 blocks(seq_len, blocks_per_seq, 1);
|
||||||
|
dim3 threads(warp_size, warps_per_block, 1);
|
||||||
|
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||||
|
switch (log2_elements) {
|
||||||
|
case 0: // 1
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 0>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 1: // 2
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 1>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 2: // 4
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 2>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 3: // 8
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 3>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 4: // 16
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 4>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 5: // 32
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 5>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 6: // 64
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 6>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 7: // 128
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 7>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 8: // 256
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 8>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 9: // 512
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 9>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 10: // 1024
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 10>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
case 11: // 2048
|
||||||
|
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
|
||||||
|
acc_t, 11>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||||
|
grad_input, grad, output, scale, batch_count,
|
||||||
|
softmax_elements_stride, softmax_elements);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace multihead_attn {
|
|
||||||
namespace fused_softmax {
|
|
||||||
namespace scaled_upper_triang_masked_softmax {
|
|
||||||
|
|
||||||
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
|
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
|
||||||
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
|
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
|
||||||
@@ -70,6 +559,3 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
|
|||||||
// backward pass is completely in-place
|
// backward pass is completely in-place
|
||||||
return output_grads;
|
return output_grads;
|
||||||
}
|
}
|
||||||
} // namespace scaled_upper_triang_masked_softmax
|
|
||||||
} // namespace fused_softmax
|
|
||||||
} // namespace multihead_attn
|
|
||||||
|
@@ -13,70 +13,27 @@ namespace utils {
|
|||||||
template <typename T, int VecSize>
|
template <typename T, int VecSize>
|
||||||
struct VecTypeTrait {};
|
struct VecTypeTrait {};
|
||||||
|
|
||||||
template <typename T>
|
#define VEC_TYPE_TRAITS_SPECIALIZATION(T, VEC_SIZE, VECT, ARGS...) \
|
||||||
struct VecTypeTrait<T, 1> {
|
template <ARGS> \
|
||||||
using Type = T;
|
struct VecTypeTrait<T, VEC_SIZE> { \
|
||||||
};
|
using Type = VECT; \
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
|
||||||
struct VecTypeTrait<c10::BFloat16, 2> {
|
VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 2, float)
|
||||||
using Type = float;
|
VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 4, float2)
|
||||||
};
|
VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 8, float4)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 2, float)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 4, float2)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 8, float4)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float4)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
|
||||||
|
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
|
||||||
|
|
||||||
template <>
|
#undef VEC_TYPE_TRAITS_SPECIALIZATION
|
||||||
struct VecTypeTrait<c10::BFloat16, 4> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<c10::BFloat16, 8> {
|
|
||||||
using Type = float4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<c10::Half, 2> {
|
|
||||||
using Type = float;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<c10::Half, 4> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<c10::Half, 8> {
|
|
||||||
using Type = float4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<float, 2> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<float, 4> {
|
|
||||||
using Type = float4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<float, 8> {
|
|
||||||
using Type = float4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<uint8_t, 2> {
|
|
||||||
using Type = half;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<uint8_t, 4> {
|
|
||||||
using Type = half2;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct VecTypeTrait<uint8_t, 8> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace utils
|
} // namespace utils
|
||||||
} // namespace cuda
|
} // namespace cuda
|
||||||
|
Reference in New Issue
Block a user