mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-03 12:49:42 +00:00
[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)
* refactor compilation mechanism and unified multi hw * fix file path bug * add `__init__.py` to make pybind a module to avoid relative path error caused by softlink * delete duplicated macros * fix macros bug in gcc
This commit is contained in:
parent
04863a9b14
commit
279300dc5f
@ -35,7 +35,7 @@ from transformers.utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN
|
from colossalai.kernel.extensions.pybind.flash_attention import HAS_FLASH_ATTN
|
||||||
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
|
||||||
from colossalai.moe.layers import SparseMLP
|
from colossalai.moe.layers import SparseMLP
|
||||||
from colossalai.moe.manager import MOE_MANAGER
|
from colossalai.moe.manager import MOE_MANAGER
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
|
from .pybind.cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
|
||||||
from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension
|
from .pybind.flash_attention import (
|
||||||
from .inference import InferenceOpsCudaExtension
|
FlashAttentionDaoCudaExtension,
|
||||||
from .layernorm import LayerNormCudaExtension
|
FlashAttentionNpuExtension,
|
||||||
from .moe import MoeCudaExtension
|
FlashAttentionSdpaCudaExtension,
|
||||||
from .optimizer import FusedOptimizerCudaExtension
|
)
|
||||||
from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension
|
from .pybind.inference import InferenceOpsCudaExtension
|
||||||
|
from .pybind.layernorm import LayerNormCudaExtension
|
||||||
|
from .pybind.moe import MoeCudaExtension
|
||||||
|
from .pybind.optimizer import FusedOptimizerCudaExtension
|
||||||
|
from .pybind.softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension
|
||||||
|
|
||||||
ALL_EXTENSIONS = [
|
ALL_EXTENSIONS = [
|
||||||
CpuAdamArmExtension,
|
CpuAdamArmExtension,
|
||||||
|
@ -25,6 +25,9 @@ class _CppExtension(_Extension):
|
|||||||
def csrc_abs_path(self, path):
|
def csrc_abs_path(self, path):
|
||||||
return os.path.join(self.relative_to_abs_path("csrc"), path)
|
return os.path.join(self.relative_to_abs_path("csrc"), path)
|
||||||
|
|
||||||
|
def pybind_abs_path(self, path):
|
||||||
|
return os.path.join(self.relative_to_abs_path("pybind"), path)
|
||||||
|
|
||||||
def relative_to_abs_path(self, code_path: str) -> str:
|
def relative_to_abs_path(self, code_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
This function takes in a path relative to the colossalai root directory and return the absolute path.
|
This function takes in a path relative to the colossalai root directory and return the absolute path.
|
||||||
@ -116,6 +119,7 @@ class _CppExtension(_Extension):
|
|||||||
"""
|
"""
|
||||||
This function should return a list of include files for extensions.
|
This function should return a list of include files for extensions.
|
||||||
"""
|
"""
|
||||||
|
return [self.csrc_abs_path("")]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cxx_flags(self) -> List[str]:
|
def cxx_flags(self) -> List[str]:
|
||||||
|
60
extensions/csrc/common/data_type.h
Normal file
60
extensions/csrc/common/data_type.h
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace dtype {
|
||||||
|
|
||||||
|
struct bfloat164 {
|
||||||
|
#ifdef COLOSSAL_WITH_CUDA
|
||||||
|
__nv_bfloat162 x;
|
||||||
|
__nv_bfloat162 y;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct bfloat168 {
|
||||||
|
#ifdef COLOSSAL_WITH_CUDA
|
||||||
|
__nv_bfloat162 x;
|
||||||
|
__nv_bfloat162 y;
|
||||||
|
__nv_bfloat162 z;
|
||||||
|
__nv_bfloat162 w;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct half4 {
|
||||||
|
#ifdef COLOSSAL_WITH_CUDA
|
||||||
|
half2 x;
|
||||||
|
half2 y;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct half8 {
|
||||||
|
#ifdef COLOSSAL_WITH_CUDA
|
||||||
|
half2 x;
|
||||||
|
half2 y;
|
||||||
|
half2 z;
|
||||||
|
half2 w;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct float4_ {
|
||||||
|
#ifdef COLOSSAL_WITH_CUDA
|
||||||
|
float2 x;
|
||||||
|
float2 y;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
struct float8_ {
|
||||||
|
#ifdef COLOSSAL_WITH_CUDA
|
||||||
|
float2 x;
|
||||||
|
float2 y;
|
||||||
|
float2 z;
|
||||||
|
float2 w;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dtype
|
||||||
|
} // namespace colossalAI
|
@ -222,3 +222,13 @@
|
|||||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||||
"'"); \
|
"'"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
|
#define HOST __host__
|
||||||
|
#define DEVICE __device__
|
||||||
|
#define HOSTDEVICE __host__ __device__
|
||||||
|
#else
|
||||||
|
#define HOST
|
||||||
|
#define DEVICE
|
||||||
|
#define HOSTDEVICE
|
||||||
|
#endif
|
||||||
|
@ -1,48 +1,16 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <stdint.h>
|
#endif
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
#include <cfloat>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "common/data_type.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
namespace common {
|
||||||
namespace utils {
|
|
||||||
|
|
||||||
struct bfloat164 {
|
|
||||||
__nv_bfloat162 x;
|
|
||||||
__nv_bfloat162 y;
|
|
||||||
};
|
|
||||||
struct bfloat168 {
|
|
||||||
__nv_bfloat162 x;
|
|
||||||
__nv_bfloat162 y;
|
|
||||||
__nv_bfloat162 z;
|
|
||||||
__nv_bfloat162 w;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct half4 {
|
|
||||||
half2 x;
|
|
||||||
half2 y;
|
|
||||||
};
|
|
||||||
struct half8 {
|
|
||||||
half2 x;
|
|
||||||
half2 y;
|
|
||||||
half2 z;
|
|
||||||
half2 w;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct float4_ {
|
|
||||||
float2 x;
|
|
||||||
float2 y;
|
|
||||||
};
|
|
||||||
struct float8_ {
|
|
||||||
float2 x;
|
|
||||||
float2 y;
|
|
||||||
float2 z;
|
|
||||||
float2 w;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, int VecSize>
|
template <typename T, int VecSize>
|
||||||
struct VecTypeTrait {};
|
struct VecTypeTrait {};
|
||||||
@ -57,6 +25,8 @@ struct FloatVecTypeTrait {};
|
|||||||
};
|
};
|
||||||
|
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
|
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
|
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
|
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
|
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
|
||||||
@ -67,16 +37,17 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
|
|||||||
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
|
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
|
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
|
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float8_)
|
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
|
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
|
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
|
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
|
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, bfloat164);
|
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164);
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, bfloat168);
|
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168);
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
|
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, half4);
|
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
|
||||||
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, half8);
|
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
|
||||||
|
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||||
|
|
||||||
#undef VEC_TYPE_TRAITS_SPECIALIZATION
|
#undef VEC_TYPE_TRAITS_SPECIALIZATION
|
||||||
|
|
||||||
@ -86,17 +57,17 @@ VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, half8);
|
|||||||
using Type = FLOATT; \
|
using Type = FLOATT; \
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2)
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2)
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4)
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4)
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2);
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2);
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(bfloat164, float4_);
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, dtype::float4_);
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(bfloat168, float8_);
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8_);
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2);
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2);
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half4, float4_);
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, dtype::float4_);
|
||||||
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half8, float8_);
|
FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8_);
|
||||||
|
#endif /* COLOSSAL_WITH_CUDA */
|
||||||
|
|
||||||
#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION
|
#undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION
|
||||||
|
} // namespace common
|
||||||
} // namespace utils
|
|
||||||
} // namespace cuda
|
|
||||||
} // namespace colossalAI
|
} // namespace colossalAI
|
@ -1,27 +1,21 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "../utils/micros.h"
|
|
||||||
#include "../utils/vec_type_traits.h"
|
|
||||||
#include "cast_functor.h"
|
#include "cast_functor.h"
|
||||||
|
#include "common/data_type.h"
|
||||||
|
#include "common/micros.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
|
||||||
namespace funcs {
|
namespace funcs {
|
||||||
|
|
||||||
using utils::bfloat164;
|
|
||||||
using utils::bfloat168;
|
|
||||||
using utils::float4_;
|
|
||||||
using utils::float8_;
|
|
||||||
using utils::half4;
|
|
||||||
using utils::half8;
|
|
||||||
|
|
||||||
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
|
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
|
||||||
|
|
||||||
// Note(LiuYang): This file provides base math operation for data type
|
// Note(LiuYang): This file provides base math operation for data type
|
||||||
@ -61,6 +55,7 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
|
|||||||
STMTS_WRAPPER({ return min(lhs, rhs); }),
|
STMTS_WRAPPER({ return min(lhs, rhs); }),
|
||||||
typename T)
|
typename T)
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
|
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
|
||||||
DEVICE, STMTS_WRAPPER({
|
DEVICE, STMTS_WRAPPER({
|
||||||
return __hadd(lhs, rhs);
|
return __hadd(lhs, rhs);
|
||||||
@ -151,8 +146,9 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||||
bfloat164, bfloat164, float4_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
|
dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul,
|
||||||
float4_ fc;
|
DEVICE, STMTS_WRAPPER({
|
||||||
|
dtype::float4_ fc;
|
||||||
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||||
BinaryOpType::kMul>
|
BinaryOpType::kMul>
|
||||||
mul;
|
mul;
|
||||||
@ -162,8 +158,9 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||||
bfloat168, bfloat168, float8_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
|
dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul,
|
||||||
float8_ fc;
|
DEVICE, STMTS_WRAPPER({
|
||||||
|
dtype::float8_ fc;
|
||||||
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||||
BinaryOpType::kMul>
|
BinaryOpType::kMul>
|
||||||
mul;
|
mul;
|
||||||
@ -184,8 +181,9 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||||
half4, half4, float4_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
|
dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE,
|
||||||
float4_ fc;
|
STMTS_WRAPPER({
|
||||||
|
dtype::float4_ fc;
|
||||||
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
||||||
fc.x = mul(lhs.x, rhs.x);
|
fc.x = mul(lhs.x, rhs.x);
|
||||||
fc.y = mul(lhs.y, rhs.y);
|
fc.y = mul(lhs.y, rhs.y);
|
||||||
@ -193,8 +191,9 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||||
half8, half8, float8_, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
|
dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE,
|
||||||
float8_ fc;
|
STMTS_WRAPPER({
|
||||||
|
dtype::float8_ fc;
|
||||||
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
||||||
fc.x = mul(lhs.x, rhs.x);
|
fc.x = mul(lhs.x, rhs.x);
|
||||||
fc.y = mul(lhs.y, rhs.y);
|
fc.y = mul(lhs.y, rhs.y);
|
||||||
@ -203,10 +202,9 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fc;
|
return fc;
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||||
|
|
||||||
#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
|
#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
|
||||||
|
|
||||||
#undef STMTS_WRAPPER
|
#undef STMTS_WRAPPER
|
||||||
|
|
||||||
} // namespace funcs
|
} // namespace funcs
|
||||||
} // namespace cuda
|
|
||||||
} // namespace colossalAI
|
} // namespace colossalAI
|
@ -1,29 +1,23 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "../utils/micros.h"
|
#include "common/data_type.h"
|
||||||
#include "../utils/vec_type_traits.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
// Note(LiuYang): This file provides base math operation for data type
|
// Note(LiuYang): This file provides base math operation for data type
|
||||||
// include POD and cuda built-in type such as half and __nv_bfloat16
|
// include POD and cuda built-in type such as half and __nv_bfloat16
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
|
||||||
namespace funcs {
|
namespace funcs {
|
||||||
|
|
||||||
using utils::bfloat164;
|
|
||||||
using utils::bfloat168;
|
|
||||||
using utils::float4_;
|
|
||||||
using utils::float8_;
|
|
||||||
using utils::half4;
|
|
||||||
using utils::half8;
|
|
||||||
|
|
||||||
template <typename From, typename To>
|
template <typename From, typename To>
|
||||||
struct CastFunctor : public std::unary_function<From, To> {
|
struct CastFunctor : public std::unary_function<From, To> {
|
||||||
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
|
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
|
||||||
@ -36,6 +30,7 @@ struct CastFunctor : public std::unary_function<From, To> {
|
|||||||
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
|
FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
|
int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
@ -54,27 +49,27 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
|||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
half, float, { return __half2float(val); }, DEVICE)
|
half, float, { return __half2float(val); }, DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float4, half4,
|
float4, dtype::half4,
|
||||||
{
|
{
|
||||||
half4 dst;
|
dtype::half4 dst;
|
||||||
dst.x = __floats2half2_rn(val.x, val.y);
|
dst.x = __floats2half2_rn(val.x, val.y);
|
||||||
dst.y = __floats2half2_rn(val.z, val.w);
|
dst.y = __floats2half2_rn(val.z, val.w);
|
||||||
return dst;
|
return dst;
|
||||||
},
|
},
|
||||||
DEVICE)
|
DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float4_, half4,
|
dtype::float4_, dtype::half4,
|
||||||
{
|
{
|
||||||
half4 dst;
|
dtype::half4 dst;
|
||||||
dst.x = __float22half2_rn(val.x);
|
dst.x = __float22half2_rn(val.x);
|
||||||
dst.y = __float22half2_rn(val.y);
|
dst.y = __float22half2_rn(val.y);
|
||||||
return dst;
|
return dst;
|
||||||
},
|
},
|
||||||
DEVICE)
|
DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float8_, half8,
|
dtype::float8_, dtype::half8,
|
||||||
{
|
{
|
||||||
half8 dst;
|
dtype::half8 dst;
|
||||||
dst.x = __float22half2_rn(val.x);
|
dst.x = __float22half2_rn(val.x);
|
||||||
dst.y = __float22half2_rn(val.y);
|
dst.y = __float22half2_rn(val.y);
|
||||||
dst.z = __float22half2_rn(val.z);
|
dst.z = __float22half2_rn(val.z);
|
||||||
@ -88,9 +83,9 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
|||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
|
float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float4, bfloat164,
|
float4, dtype::bfloat164,
|
||||||
{
|
{
|
||||||
bfloat164 dst;
|
dtype::bfloat164 dst;
|
||||||
dst.x = __floats2bfloat162_rn(val.x, val.y);
|
dst.x = __floats2bfloat162_rn(val.x, val.y);
|
||||||
dst.y = __floats2bfloat162_rn(val.z, val.w);
|
dst.y = __floats2bfloat162_rn(val.z, val.w);
|
||||||
return dst;
|
return dst;
|
||||||
@ -105,18 +100,18 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
|||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
|
float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float4_, bfloat164,
|
dtype::float4_, dtype::bfloat164,
|
||||||
{
|
{
|
||||||
bfloat164 dst;
|
dtype::bfloat164 dst;
|
||||||
dst.x = __float22bfloat162_rn(val.x);
|
dst.x = __float22bfloat162_rn(val.x);
|
||||||
dst.y = __float22bfloat162_rn(val.y);
|
dst.y = __float22bfloat162_rn(val.y);
|
||||||
return dst;
|
return dst;
|
||||||
},
|
},
|
||||||
DEVICE)
|
DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float8_, bfloat168,
|
dtype::float8_, dtype::bfloat168,
|
||||||
{
|
{
|
||||||
bfloat168 dst;
|
dtype::bfloat168 dst;
|
||||||
dst.x = __float22bfloat162_rn(val.x);
|
dst.x = __float22bfloat162_rn(val.x);
|
||||||
dst.y = __float22bfloat162_rn(val.y);
|
dst.y = __float22bfloat162_rn(val.y);
|
||||||
dst.z = __float22bfloat162_rn(val.z);
|
dst.z = __float22bfloat162_rn(val.z);
|
||||||
@ -141,18 +136,18 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
|||||||
float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
|
float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
|
||||||
DEVICE)
|
DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float4_, bfloat164,
|
dtype::float4_, dtype::bfloat164,
|
||||||
{
|
{
|
||||||
bfloat164 dst;
|
dtype::bfloat164 dst;
|
||||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||||
return dst;
|
return dst;
|
||||||
},
|
},
|
||||||
DEVICE)
|
DEVICE)
|
||||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||||
float8_, bfloat168,
|
dtype::float8_, dtype::bfloat168,
|
||||||
{
|
{
|
||||||
bfloat168 dst;
|
dtype::bfloat168 dst;
|
||||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||||
dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
|
dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
|
||||||
@ -161,8 +156,8 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
|||||||
},
|
},
|
||||||
DEVICE)
|
DEVICE)
|
||||||
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */
|
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */
|
||||||
|
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||||
|
|
||||||
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
|
#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
|
||||||
} // namespace funcs
|
} // namespace funcs
|
||||||
} // namespace cuda
|
|
||||||
} // namespace colossalAI
|
} // namespace colossalAI
|
@ -1,13 +1,13 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include "../funcs/binary_functor.h"
|
#include "binary_functor.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
|
||||||
namespace funcs {
|
namespace funcs {
|
||||||
|
|
||||||
const float kReduceFloatInfNeg = -100000000.f;
|
const float kReduceFloatInfNeg = -100000000.f;
|
||||||
@ -89,5 +89,6 @@ __forceinline__ __device__ void block_reduce(T* pval) {
|
|||||||
#undef COLOSSAL_BLOCK_REDUCE_IMPL
|
#undef COLOSSAL_BLOCK_REDUCE_IMPL
|
||||||
|
|
||||||
} // namespace funcs
|
} // namespace funcs
|
||||||
} // namespace cuda
|
|
||||||
} // namespace colossalAI
|
} // namespace colossalAI
|
||||||
|
|
||||||
|
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
@ -1,18 +1,20 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "../funcs/cast_functor.h"
|
#include "cast_functor.h"
|
||||||
#include "../utils/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
|
||||||
namespace funcs {
|
namespace funcs {
|
||||||
|
|
||||||
enum class TernaryOpType { kFma = 0 };
|
enum class TernaryOpType { kFma = 0 };
|
||||||
@ -29,6 +31,7 @@ struct TernaryOpFunctor;
|
|||||||
FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS \
|
FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS \
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float,
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float,
|
||||||
TernaryOpType::kFma, DEVICE,
|
TernaryOpType::kFma, DEVICE,
|
||||||
STMTS_WRAPPER({
|
STMTS_WRAPPER({
|
||||||
@ -91,16 +94,18 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fma(cast(a), b, c);
|
return fma(cast(a), b, c);
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
half4, half4, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
|
||||||
float4_ fd;
|
STMTS_WRAPPER({
|
||||||
|
dtype::float4_ fd;
|
||||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||||
fd.x = fma(a.x, b.x, c.x);
|
fd.x = fma(a.x, b.x, c.x);
|
||||||
fd.y = fma(a.y, b.y, c.y);
|
fd.y = fma(a.y, b.y, c.y);
|
||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
half, half4, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
|
||||||
float4_ fd;
|
STMTS_WRAPPER({
|
||||||
|
dtype::float4_ fd;
|
||||||
CastFunctor<half, half2> cast;
|
CastFunctor<half, half2> cast;
|
||||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||||
half2 s = cast(a);
|
half2 s = cast(a);
|
||||||
@ -109,8 +114,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
half8, half8, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
|
||||||
float8_ fd;
|
STMTS_WRAPPER({
|
||||||
|
dtype::float8_ fd;
|
||||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||||
fd.x = fma(a.x, b.x, c.x);
|
fd.x = fma(a.x, b.x, c.x);
|
||||||
fd.y = fma(a.y, b.y, c.y);
|
fd.y = fma(a.y, b.y, c.y);
|
||||||
@ -119,8 +125,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
half, half8, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
|
||||||
float8_ fd;
|
STMTS_WRAPPER({
|
||||||
|
dtype::float8_ fd;
|
||||||
CastFunctor<half, half2> cast;
|
CastFunctor<half, half2> cast;
|
||||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||||
half2 s = cast(a);
|
half2 s = cast(a);
|
||||||
@ -153,8 +160,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fma(cast(a), b, c);
|
return fma(cast(a), b, c);
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
bfloat164, bfloat164, float4_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
|
||||||
float4_ fd;
|
DEVICE, STMTS_WRAPPER({
|
||||||
|
dtype::float4_ fd;
|
||||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||||
TernaryOpType::kFma>
|
TernaryOpType::kFma>
|
||||||
fma;
|
fma;
|
||||||
@ -163,9 +171,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
__nv_bfloat16, bfloat164, float4_, TernaryOpType::kFma, DEVICE,
|
__nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
|
||||||
STMTS_WRAPPER({
|
DEVICE, STMTS_WRAPPER({
|
||||||
float4_ fd;
|
dtype::float4_ fd;
|
||||||
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
||||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||||
TernaryOpType::kFma>
|
TernaryOpType::kFma>
|
||||||
@ -176,8 +184,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
bfloat168, bfloat168, float8_, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
|
||||||
float8_ fd;
|
DEVICE, STMTS_WRAPPER({
|
||||||
|
dtype::float8_ fd;
|
||||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||||
TernaryOpType::kFma>
|
TernaryOpType::kFma>
|
||||||
fma;
|
fma;
|
||||||
@ -188,9 +197,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||||
__nv_bfloat16, bfloat168, float8_, TernaryOpType::kFma, DEVICE,
|
__nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
|
||||||
STMTS_WRAPPER({
|
DEVICE, STMTS_WRAPPER({
|
||||||
float8_ fd;
|
dtype::float8_ fd;
|
||||||
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
||||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||||
TernaryOpType::kFma>
|
TernaryOpType::kFma>
|
||||||
@ -203,10 +212,10 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
|||||||
return fd;
|
return fd;
|
||||||
}))
|
}))
|
||||||
|
|
||||||
#undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION
|
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||||
|
|
||||||
|
#undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION
|
||||||
#undef STMTS_WRAPPER
|
#undef STMTS_WRAPPER
|
||||||
|
|
||||||
} // namespace funcs
|
} // namespace funcs
|
||||||
} // namespace cuda
|
|
||||||
} // namespace colossalAI
|
} // namespace colossalAI
|
@ -1,16 +1,18 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "../utils/micros.h"
|
#include "common/data_type.h"
|
||||||
|
#include "common/micros.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
|
||||||
namespace funcs {
|
namespace funcs {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -57,27 +59,30 @@ COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
|
|||||||
return log2_value;
|
return log2_value;
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#if defined(COLOSSAL_WITH_CUDA)
|
||||||
|
|
||||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE,
|
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE,
|
||||||
{ return val.x + val.y; })
|
{ return val.x + val.y; })
|
||||||
|
|
||||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE,
|
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE,
|
||||||
{ return val.x + val.y + val.z + val.w; })
|
{ return val.x + val.y + val.z + val.w; })
|
||||||
|
|
||||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4_, float, UnaryOpType::kSum, DEVICE,
|
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float4_, float, UnaryOpType::kSum,
|
||||||
{
|
DEVICE, {
|
||||||
return val.x.x + val.x.y + val.y.x +
|
return val.x.x + val.x.y + val.y.x +
|
||||||
val.y.y;
|
val.y.y;
|
||||||
})
|
})
|
||||||
|
|
||||||
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float8_, float, UnaryOpType::kSum, DEVICE,
|
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8_, float, UnaryOpType::kSum,
|
||||||
{
|
DEVICE, {
|
||||||
return val.x.x + val.x.y + val.y.x +
|
return val.x.x + val.x.y + val.y.x +
|
||||||
val.y.y + val.z.x + val.z.y +
|
val.y.y + val.z.x + val.z.y +
|
||||||
val.w.x + val.w.y;
|
val.w.x + val.w.y;
|
||||||
})
|
})
|
||||||
|
|
||||||
|
#endif /* defined(COLOSSAL_WITH_CUDA) */
|
||||||
|
|
||||||
#undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION
|
#undef COLOSSAL_UARY_FUNCTOR_SPECIALIZATION
|
||||||
|
|
||||||
} // namespace funcs
|
} // namespace funcs
|
||||||
} // namespace cuda
|
|
||||||
} // namespace colossalAI
|
} // namespace colossalAI
|
@ -2,13 +2,15 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "../common/mp_type_traits.h"
|
#include "common/mp_type_traits.h"
|
||||||
|
|
||||||
|
using colossalAI::common::MPTypeTrait;
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||||
// x * sigmoid(x)
|
// x * sigmoid(x)
|
||||||
using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
|
using MT = typename MPTypeTrait<T>::Type;
|
||||||
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
|
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -17,7 +19,7 @@ __global__ void act_and_mul_kernel(
|
|||||||
const scalar_t* __restrict__ ins_data,
|
const scalar_t* __restrict__ ins_data,
|
||||||
scalar_t* __restrict__ outs_data,
|
scalar_t* __restrict__ outs_data,
|
||||||
const int64_t numel) {
|
const int64_t numel) {
|
||||||
using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
|
using MT = typename MPTypeTrait<scalar_t>::Type;
|
||||||
|
|
||||||
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
|
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
|
||||||
const int64_t grid_size = blockDim.x * gridDim.x;
|
const int64_t grid_size = blockDim.x * gridDim.x;
|
@ -23,24 +23,16 @@
|
|||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
|
|
||||||
#include "../funcs/binary_functor.h"
|
#include "common/vec_type_traits.h"
|
||||||
#include "../funcs/cast_functor.h"
|
#include "funcs/binary_functor.h"
|
||||||
#include "../funcs/ternary_functor.h"
|
#include "funcs/cast_functor.h"
|
||||||
#include "../funcs/unary_functor.h"
|
#include "funcs/ternary_functor.h"
|
||||||
#include "../utils/vec_type_traits.h"
|
#include "funcs/unary_functor.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
namespace cuda {
|
||||||
namespace attention {
|
namespace attention {
|
||||||
|
|
||||||
using colossalAI::cuda::funcs::BinaryOpFunctor;
|
|
||||||
using colossalAI::cuda::funcs::BinaryOpType;
|
|
||||||
using colossalAI::cuda::funcs::TernaryOpFunctor;
|
|
||||||
using colossalAI::cuda::funcs::TernaryOpType;
|
|
||||||
using colossalAI::cuda::funcs::UnaryOpFunctor;
|
|
||||||
using colossalAI::cuda::funcs::UnaryOpType;
|
|
||||||
using colossalAI::cuda::utils::FloatVecTypeTrait;
|
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
#define VEC_SIZE_8 8
|
#define VEC_SIZE_8 8
|
||||||
|
|
||||||
@ -51,11 +43,11 @@ using colossalAI::cuda::utils::FloatVecTypeTrait;
|
|||||||
// Q*K^T operation.
|
// Q*K^T operation.
|
||||||
template <int NUM_THREADS_PER_TOKEN, typename VecT, int N>
|
template <int NUM_THREADS_PER_TOKEN, typename VecT, int N>
|
||||||
inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
|
inline __device__ float qk_dot_(const VecT (&q)[N], const VecT (&k)[N]) {
|
||||||
using A_vec = typename FloatVecTypeTrait<VecT>::Type;
|
using A_vec = typename common::FloatVecTypeTrait<VecT>::Type;
|
||||||
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
||||||
BinaryOpFunctor<VecT, VecT, A_vec, BinaryOpType::kMul> mul_vect;
|
funcs::BinaryOpFunctor<VecT, VecT, A_vec, funcs::BinaryOpType::kMul> mul_vect;
|
||||||
UnaryOpFunctor<A_vec, float, UnaryOpType::kSum> sum_vect;
|
funcs::UnaryOpFunctor<A_vec, float, funcs::UnaryOpType::kSum> sum_vect;
|
||||||
TernaryOpFunctor<VecT, VecT, A_vec, TernaryOpType::kFma> fma;
|
funcs::TernaryOpFunctor<VecT, VecT, A_vec, funcs::TernaryOpType::kFma> fma;
|
||||||
|
|
||||||
A_vec qk_vec = mul_vect(q[0], k[0]);
|
A_vec qk_vec = mul_vect(q[0], k[0]);
|
||||||
#pragma unroll
|
#pragma unroll
|
@ -2,7 +2,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vec_copy.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
using colossalAI::cuda::utils::copy_vector;
|
using colossalAI::cuda::utils::copy_vector;
|
||||||
using colossalAI::cuda::utils::get_vec_size;
|
using colossalAI::cuda::utils::get_vec_size;
|
@ -2,7 +2,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vec_copy.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
using colossalAI::cuda::utils::copy_vector;
|
using colossalAI::cuda::utils::copy_vector;
|
||||||
using colossalAI::cuda::utils::get_vec_size;
|
using colossalAI::cuda::utils::get_vec_size;
|
@ -7,11 +7,11 @@
|
|||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "funcs/cast_functor.h"
|
#include "funcs/cast_functor.h"
|
||||||
#include "funcs/ternary_functor.h"
|
#include "funcs/ternary_functor.h"
|
||||||
#include "funcs/binary_functor.h"
|
#include "funcs/binary_functor.h"
|
||||||
#include "utils/vec_type_traits.h"
|
#include "common/vec_type_traits.h"
|
||||||
#include "attention/attention_utils.h"
|
#include "attention/attention_utils.h"
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
@ -34,13 +34,13 @@ constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
|
|||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
using colossalAI::cuda::funcs::BinaryOpType;
|
using colossalAI::funcs::BinaryOpType;
|
||||||
using colossalAI::cuda::funcs::CastFunctor;
|
using colossalAI::funcs::CastFunctor;
|
||||||
using colossalAI::cuda::funcs::TernaryOpFunctor;
|
using colossalAI::funcs::TernaryOpFunctor;
|
||||||
using colossalAI::cuda::funcs::TernaryOpType;
|
using colossalAI::funcs::TernaryOpType;
|
||||||
using colossalAI::cuda::funcs::zero;
|
using colossalAI::funcs::zero;
|
||||||
using colossalAI::cuda::utils::VecTypeTrait;
|
using colossalAI::common::VecTypeTrait;
|
||||||
using colossalAI::cuda::utils::FloatVecTypeTrait;
|
using colossalAI::common::FloatVecTypeTrait;
|
||||||
using namespace colossalAI::cuda::attention;
|
using namespace colossalAI::cuda::attention;
|
||||||
|
|
||||||
|
|
@ -3,8 +3,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vec_copy.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "../common/mp_type_traits.h"
|
#include "common/mp_type_traits.h"
|
||||||
|
|
||||||
using colossalAI::cuda::utils::copy_vector;
|
using colossalAI::cuda::utils::copy_vector;
|
||||||
using colossalAI::cuda::utils::get_vec_size;
|
using colossalAI::cuda::utils::get_vec_size;
|
@ -2,7 +2,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "utils/vec_copy.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
using colossalAI::cuda::utils::copy_vector;
|
using colossalAI::cuda::utils::copy_vector;
|
||||||
using colossalAI::cuda::utils::get_vec_size;
|
using colossalAI::cuda::utils::get_vec_size;
|
@ -9,7 +9,7 @@
|
|||||||
#include "ATen/AccumulateType.h"
|
#include "ATen/AccumulateType.h"
|
||||||
#include "ATen/cuda/CUDAContext.h"
|
#include "ATen/cuda/CUDAContext.h"
|
||||||
#include "ATen/cuda/DeviceUtils.cuh"
|
#include "ATen/cuda/DeviceUtils.cuh"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
template <typename U>
|
template <typename U>
|
||||||
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
@ -6,9 +6,8 @@
|
|||||||
|
|
||||||
#include "funcs/reduce_function.h"
|
#include "funcs/reduce_function.h"
|
||||||
|
|
||||||
|
using colossalAI::funcs::block_reduce;
|
||||||
using colossalAI::cuda::funcs::block_reduce;
|
using colossalAI::funcs::ReduceType;
|
||||||
using colossalAI::cuda::funcs::ReduceType;
|
|
||||||
|
|
||||||
template <typename T, int block_size, int pack_size>
|
template <typename T, int block_size, int pack_size>
|
||||||
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||||
@ -540,7 +539,7 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
|||||||
|
|
||||||
// API FUNCTIONS --------------------------------
|
// API FUNCTIONS --------------------------------
|
||||||
|
|
||||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
#define DISPATCH_FLOAT_AND_HALF_MOE(TYPE, NAME, ...) \
|
||||||
switch (TYPE) { \
|
switch (TYPE) { \
|
||||||
case at::ScalarType::Float: { \
|
case at::ScalarType::Float: { \
|
||||||
using scalar_t = float; \
|
using scalar_t = float; \
|
||||||
@ -566,7 +565,7 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
|||||||
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
|
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
|
||||||
auto k = mask.size(0);
|
auto k = mask.size(0);
|
||||||
|
|
||||||
DISPATCH_FLOAT_AND_HALF(
|
DISPATCH_FLOAT_AND_HALF_MOE(
|
||||||
batch_tokens.scalar_type(), "moe dispatch forward",
|
batch_tokens.scalar_type(), "moe dispatch forward",
|
||||||
moe_dpch_fwd_launch<scalar_t>(
|
moe_dpch_fwd_launch<scalar_t>(
|
||||||
batch_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
|
batch_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
|
||||||
@ -586,7 +585,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
|||||||
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||||
auto k = mask.size(0);
|
auto k = mask.size(0);
|
||||||
|
|
||||||
DISPATCH_FLOAT_AND_HALF(
|
DISPATCH_FLOAT_AND_HALF_MOE(
|
||||||
expert_grad.scalar_type(), "moe dispatch backward",
|
expert_grad.scalar_type(), "moe dispatch backward",
|
||||||
moe_dpch_bwd_launch<scalar_t>(
|
moe_dpch_bwd_launch<scalar_t>(
|
||||||
res.data_ptr<scalar_t>(), expert_grad.data_ptr<scalar_t>(),
|
res.data_ptr<scalar_t>(), expert_grad.data_ptr<scalar_t>(),
|
||||||
@ -609,7 +608,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|||||||
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
|
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
|
||||||
auto k = mask.size(0);
|
auto k = mask.size(0);
|
||||||
|
|
||||||
DISPATCH_FLOAT_AND_HALF(
|
DISPATCH_FLOAT_AND_HALF_MOE(
|
||||||
expert_tokens.scalar_type(), "moe combine forward",
|
expert_tokens.scalar_type(), "moe combine forward",
|
||||||
moe_cb_fwd_launch<scalar_t>(
|
moe_cb_fwd_launch<scalar_t>(
|
||||||
expert_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
|
expert_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
|
||||||
@ -636,7 +635,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
|
|||||||
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
||||||
auto k = mask.size(0);
|
auto k = mask.size(0);
|
||||||
|
|
||||||
DISPATCH_FLOAT_AND_HALF(
|
DISPATCH_FLOAT_AND_HALF_MOE(
|
||||||
tokens_grad.scalar_type(), "moe combine backward",
|
tokens_grad.scalar_type(), "moe combine backward",
|
||||||
moe_cb_bwd_launch<scalar_t>(
|
moe_cb_bwd_launch<scalar_t>(
|
||||||
tokens_grad.data_ptr<scalar_t>(), egrad.data_ptr<scalar_t>(),
|
tokens_grad.data_ptr<scalar_t>(), egrad.data_ptr<scalar_t>(),
|
@ -15,7 +15,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
@ -12,7 +12,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
// #include <iostream>
|
// #include <iostream>
|
||||||
|
|
@ -11,8 +11,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "funcs/reduce_function.h"
|
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
@ -10,7 +10,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
@ -10,7 +10,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
@ -7,7 +7,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
@ -7,18 +7,18 @@
|
|||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "funcs/cast_functor.h"
|
#include "funcs/cast_functor.h"
|
||||||
#include "funcs/binary_functor.h"
|
#include "funcs/binary_functor.h"
|
||||||
#include "funcs/reduce_function.h"
|
#include "funcs/reduce_function.h"
|
||||||
#include "utils/vec_type_traits.h"
|
#include "common/vec_type_traits.h"
|
||||||
|
|
||||||
using colossalAI::cuda::funcs::block_reduce;
|
using colossalAI::funcs::block_reduce;
|
||||||
using colossalAI::cuda::funcs::ReduceType;
|
using colossalAI::funcs::ReduceType;
|
||||||
using colossalAI::cuda::funcs::CastFunctor;
|
using colossalAI::funcs::CastFunctor;
|
||||||
using colossalAI::cuda::funcs::BinaryOpFunctor;
|
using colossalAI::funcs::BinaryOpFunctor;
|
||||||
using colossalAI::cuda::funcs::BinaryOpType;
|
using colossalAI::funcs::BinaryOpType;
|
||||||
using colossalAI::cuda::utils::VecTypeTrait;
|
using colossalAI::common::VecTypeTrait;
|
||||||
|
|
||||||
#define RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM) \
|
#define RMSNORM_LAUNCHER(UNROLL_FACTOR, THREADDIM) \
|
||||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( \
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( \
|
@ -14,15 +14,15 @@
|
|||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "utils/vec_copy.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "funcs/reduce_function.h"
|
#include "funcs/reduce_function.h"
|
||||||
#include "funcs/unary_functor.h"
|
#include "funcs/unary_functor.h"
|
||||||
|
|
||||||
using colossalAI::cuda::funcs::UnaryOpFunctor;
|
using colossalAI::funcs::UnaryOpFunctor;
|
||||||
using colossalAI::cuda::funcs::UnaryOpType;
|
using colossalAI::funcs::UnaryOpType;
|
||||||
using colossalAI::cuda::funcs::warp_reduce;
|
using colossalAI::funcs::warp_reduce;
|
||||||
using colossalAI::cuda::funcs::ReduceType;
|
using colossalAI::funcs::ReduceType;
|
||||||
using colossalAI::cuda::utils::copy_vector;
|
using colossalAI::cuda::utils::copy_vector;
|
||||||
|
|
||||||
|
|
@ -14,15 +14,15 @@
|
|||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "../common/micros.h"
|
#include "common/micros.h"
|
||||||
#include "utils/vec_copy.h"
|
#include "utils/vec_copy.h"
|
||||||
#include "funcs/reduce_function.h"
|
#include "funcs/reduce_function.h"
|
||||||
#include "funcs/unary_functor.h"
|
#include "funcs/unary_functor.h"
|
||||||
|
|
||||||
using colossalAI::cuda::funcs::UnaryOpFunctor;
|
using colossalAI::funcs::UnaryOpFunctor;
|
||||||
using colossalAI::cuda::funcs::UnaryOpType;
|
using colossalAI::funcs::UnaryOpType;
|
||||||
using colossalAI::cuda::funcs::warp_reduce;
|
using colossalAI::funcs::warp_reduce;
|
||||||
using colossalAI::cuda::funcs::ReduceType;
|
using colossalAI::funcs::ReduceType;
|
||||||
using colossalAI::cuda::utils::copy_vector;
|
using colossalAI::cuda::utils::copy_vector;
|
||||||
using colossalAI::cuda::utils::copy_zero_vector;
|
using colossalAI::cuda::utils::copy_zero_vector;
|
||||||
|
|
@ -4,8 +4,8 @@
|
|||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include "../funcs/cast_functor.h"
|
#include "common/vec_type_traits.h"
|
||||||
#include "vec_type_traits.h"
|
#include "funcs/cast_functor.h"
|
||||||
|
|
||||||
namespace colossalAI {
|
namespace colossalAI {
|
||||||
namespace cuda {
|
namespace cuda {
|
||||||
@ -13,7 +13,7 @@ namespace utils {
|
|||||||
|
|
||||||
template <typename T, int VecSize>
|
template <typename T, int VecSize>
|
||||||
__device__ __inline__ void copy_vector(T *dst, const T *src) {
|
__device__ __inline__ void copy_vector(T *dst, const T *src) {
|
||||||
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
|
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
|
||||||
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
|
// Note(LiuYang): Here static_cast can't be used for cast between two pointer
|
||||||
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
|
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
|
||||||
}
|
}
|
||||||
@ -29,9 +29,8 @@ __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
|
|||||||
|
|
||||||
template <typename T, int VecSize>
|
template <typename T, int VecSize>
|
||||||
__device__ __inline__ void copy_zero_vector(T *dst) {
|
__device__ __inline__ void copy_zero_vector(T *dst) {
|
||||||
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
|
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
|
||||||
*(reinterpret_cast<VT *>(dst)) =
|
*(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
|
||||||
colossalAI::cuda::funcs::CastFunctor<float, VT>()(0.0f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
@ -21,6 +21,7 @@ class _CudaExtension(_CppExtension):
|
|||||||
"""
|
"""
|
||||||
This function should return a list of nvcc compilation flags for extensions.
|
This function should return a list of nvcc compilation flags for extensions.
|
||||||
"""
|
"""
|
||||||
|
return ["-DCOLOSSAL_WITH_CUDA"]
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
def is_available(self) -> bool:
|
||||||
# cuda extension can only be built if cuda is available
|
# cuda extension can only be built if cuda is available
|
||||||
@ -53,6 +54,12 @@ class _CudaExtension(_CppExtension):
|
|||||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||||
return cuda_include
|
return cuda_include
|
||||||
|
|
||||||
|
def include_dirs(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
This function should return a list of include files for extensions.
|
||||||
|
"""
|
||||||
|
return super().include_dirs() + [self.get_cuda_home_include()]
|
||||||
|
|
||||||
def build_jit(self) -> None:
|
def build_jit(self) -> None:
|
||||||
from torch.utils.cpp_extension import CUDA_HOME, load
|
from torch.utils.cpp_extension import CUDA_HOME, load
|
||||||
|
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
from ..cuda_extension import _CudaExtension
|
|
||||||
from ..utils import get_cuda_cc_flag
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceOpsCudaExtension(_CudaExtension):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(name="inference_ops_cuda")
|
|
||||||
|
|
||||||
def sources_files(self):
|
|
||||||
ret = [
|
|
||||||
self.csrc_abs_path(fname)
|
|
||||||
for fname in [
|
|
||||||
"cuda/pybind/inference.cpp",
|
|
||||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
|
||||||
"cuda/context_kv_cache_memcpy_kernel.cu",
|
|
||||||
"cuda/fused_rotary_emb_and_cache_kernel.cu",
|
|
||||||
"cuda/activation_kernel.cu",
|
|
||||||
"cuda/rms_layernorm_kernel.cu",
|
|
||||||
"cuda/get_cos_and_sin_kernel.cu",
|
|
||||||
"cuda/flash_decoding_attention_kernel.cu",
|
|
||||||
]
|
|
||||||
]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def include_dirs(self):
|
|
||||||
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def cxx_flags(self):
|
|
||||||
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
|
||||||
return ["-O3"] + version_dependent_macros
|
|
||||||
|
|
||||||
def nvcc_flags(self):
|
|
||||||
extra_cuda_flags = ["-lineinfo"]
|
|
||||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
|
||||||
return ["-O3", "--use_fast_math"] + extra_cuda_flags
|
|
0
extensions/pybind/__init__.py
Normal file
0
extensions/pybind/__init__.py
Normal file
@ -1,6 +1,7 @@
|
|||||||
import platform
|
import platform
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from ..cpp_extension import _CppExtension
|
from ...cpp_extension import _CppExtension
|
||||||
|
|
||||||
|
|
||||||
class CpuAdamArmExtension(_CppExtension):
|
class CpuAdamArmExtension(_CppExtension):
|
||||||
@ -20,12 +21,12 @@ class CpuAdamArmExtension(_CppExtension):
|
|||||||
# necessary 4 functions
|
# necessary 4 functions
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path("arm/cpu_adam_arm.cpp"),
|
self.csrc_abs_path("kernel/arm/cpu_adam_arm.cpp"),
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
def include_dirs(self) -> List[str]:
|
||||||
return []
|
return super().include_dirs()
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
extra_cxx_flags = [
|
extra_cxx_flags = [
|
@ -1,7 +1,7 @@
|
|||||||
import platform
|
import platform
|
||||||
|
|
||||||
from ..cuda_extension import _CudaExtension
|
from ...cuda_extension import _CudaExtension
|
||||||
from ..utils import append_nvcc_threads
|
from ...utils import append_nvcc_threads
|
||||||
|
|
||||||
|
|
||||||
class CpuAdamX86Extension(_CudaExtension):
|
class CpuAdamX86Extension(_CudaExtension):
|
||||||
@ -21,13 +21,10 @@ class CpuAdamX86Extension(_CudaExtension):
|
|||||||
# necessary 4 functions
|
# necessary 4 functions
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path("x86/cpu_adam.cpp"),
|
self.csrc_abs_path("kernel/x86/cpu_adam.cpp"),
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
|
||||||
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
|
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
extra_cxx_flags = [
|
extra_cxx_flags = [
|
||||||
"-std=c++14",
|
"-std=c++14",
|
||||||
@ -50,5 +47,5 @@ class CpuAdamX86Extension(_CudaExtension):
|
|||||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||||
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
|
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
|
||||||
]
|
]
|
||||||
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
|
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags()
|
||||||
return append_nvcc_threads(ret)
|
return append_nvcc_threads(ret)
|
@ -1,4 +1,4 @@
|
|||||||
from ..base_extension import _Extension
|
from ...base_extension import _Extension
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionDaoCudaExtension(_Extension):
|
class FlashAttentionDaoCudaExtension(_Extension):
|
@ -1,4 +1,4 @@
|
|||||||
from ..base_extension import _Extension
|
from ...base_extension import _Extension
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionNpuExtension(_Extension):
|
class FlashAttentionNpuExtension(_Extension):
|
@ -1,4 +1,4 @@
|
|||||||
from ..base_extension import _Extension
|
from ...base_extension import _Extension
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionSdpaCudaExtension(_Extension):
|
class FlashAttentionSdpaCudaExtension(_Extension):
|
31
extensions/pybind/inference/inference_ops_cuda.py
Normal file
31
extensions/pybind/inference/inference_ops_cuda.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from ...cuda_extension import _CudaExtension
|
||||||
|
from ...utils import get_cuda_cc_flag
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceOpsCudaExtension(_CudaExtension):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(name="inference_ops_cuda")
|
||||||
|
|
||||||
|
def sources_files(self):
|
||||||
|
ret = [
|
||||||
|
self.csrc_abs_path(fname)
|
||||||
|
for fname in [
|
||||||
|
"kernel/cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||||
|
"kernel/cuda/context_kv_cache_memcpy_kernel.cu",
|
||||||
|
"kernel/cuda/fused_rotary_emb_and_cache_kernel.cu",
|
||||||
|
"kernel/cuda/activation_kernel.cu",
|
||||||
|
"kernel/cuda/rms_layernorm_kernel.cu",
|
||||||
|
"kernel/cuda/get_cos_and_sin_kernel.cu",
|
||||||
|
"kernel/cuda/flash_decoding_attention_kernel.cu",
|
||||||
|
]
|
||||||
|
] + [self.pybind_abs_path("inference/inference.cpp")]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def cxx_flags(self):
|
||||||
|
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||||
|
return ["-O3"] + version_dependent_macros
|
||||||
|
|
||||||
|
def nvcc_flags(self):
|
||||||
|
extra_cuda_flags = ["-lineinfo"]
|
||||||
|
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||||
|
return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
|
@ -7,7 +7,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../../common/micros.h"
|
#include "common/micros.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
@ -1,5 +1,5 @@
|
|||||||
from ..cuda_extension import _CudaExtension
|
from ...cuda_extension import _CudaExtension
|
||||||
from ..utils import append_nvcc_threads, get_cuda_cc_flag
|
from ...utils import append_nvcc_threads, get_cuda_cc_flag
|
||||||
|
|
||||||
|
|
||||||
class LayerNormCudaExtension(_CudaExtension):
|
class LayerNormCudaExtension(_CudaExtension):
|
||||||
@ -7,11 +7,13 @@ class LayerNormCudaExtension(_CudaExtension):
|
|||||||
super().__init__(name="layernorm_cuda")
|
super().__init__(name="layernorm_cuda")
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/layer_norm.cpp", "cuda/layer_norm_kernel.cu"]]
|
ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/layer_norm_kernel.cu"]] + [
|
||||||
|
self.pybind_abs_path("layernorm/layer_norm.cpp")
|
||||||
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
def include_dirs(self):
|
||||||
ret = [self.get_cuda_home_include()]
|
ret = [self.get_cuda_home_include()] + [self.csrc_abs_path("")]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
@ -20,5 +22,5 @@ class LayerNormCudaExtension(_CudaExtension):
|
|||||||
def nvcc_flags(self):
|
def nvcc_flags(self):
|
||||||
extra_cuda_flags = ["-maxrregcount=50"]
|
extra_cuda_flags = ["-maxrregcount=50"]
|
||||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||||
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros
|
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + super().nvcc_flags()
|
||||||
return append_nvcc_threads(ret)
|
return append_nvcc_threads(ret)
|
@ -1,17 +1,15 @@
|
|||||||
from ..cuda_extension import _CudaExtension
|
from ...cuda_extension import _CudaExtension
|
||||||
from ..utils import append_nvcc_threads, get_cuda_cc_flag
|
from ...utils import append_nvcc_threads, get_cuda_cc_flag
|
||||||
|
|
||||||
|
|
||||||
class MoeCudaExtension(_CudaExtension):
|
class MoeCudaExtension(_CudaExtension):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(name="moe_cuda")
|
super().__init__(name="moe_cuda")
|
||||||
|
|
||||||
def include_dirs(self):
|
|
||||||
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/moe.cpp", "cuda/moe_kernel.cu"]]
|
ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/moe_kernel.cu"]] + [
|
||||||
|
self.pybind_abs_path("moe/moe.cpp")
|
||||||
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
@ -25,5 +23,5 @@ class MoeCudaExtension(_CudaExtension):
|
|||||||
"--expt-extended-lambda",
|
"--expt-extended-lambda",
|
||||||
]
|
]
|
||||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||||
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
|
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
|
||||||
return append_nvcc_threads(ret)
|
return append_nvcc_threads(ret)
|
@ -1,5 +1,5 @@
|
|||||||
from ..cuda_extension import _CudaExtension
|
from ...cuda_extension import _CudaExtension
|
||||||
from ..utils import get_cuda_cc_flag
|
from ...utils import get_cuda_cc_flag
|
||||||
|
|
||||||
|
|
||||||
class FusedOptimizerCudaExtension(_CudaExtension):
|
class FusedOptimizerCudaExtension(_CudaExtension):
|
||||||
@ -10,18 +10,13 @@ class FusedOptimizerCudaExtension(_CudaExtension):
|
|||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/pybind/optimizer.cpp",
|
"kernel/cuda/multi_tensor_sgd_kernel.cu",
|
||||||
"cuda/multi_tensor_sgd_kernel.cu",
|
"kernel/cuda/multi_tensor_scale_kernel.cu",
|
||||||
"cuda/multi_tensor_scale_kernel.cu",
|
"kernel/cuda/multi_tensor_adam_kernel.cu",
|
||||||
"cuda/multi_tensor_adam_kernel.cu",
|
"kernel/cuda/multi_tensor_l2norm_kernel.cu",
|
||||||
"cuda/multi_tensor_l2norm_kernel.cu",
|
"kernel/cuda/multi_tensor_lamb_kernel.cu",
|
||||||
"cuda/multi_tensor_lamb_kernel.cu",
|
|
||||||
]
|
]
|
||||||
]
|
] + [self.pybind_abs_path("optimizer/optimizer.cpp")]
|
||||||
return ret
|
|
||||||
|
|
||||||
def include_dirs(self):
|
|
||||||
ret = [self.get_cuda_home_include()]
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
@ -31,4 +26,4 @@ class FusedOptimizerCudaExtension(_CudaExtension):
|
|||||||
def nvcc_flags(self):
|
def nvcc_flags(self):
|
||||||
extra_cuda_flags = ["-lineinfo"]
|
extra_cuda_flags = ["-lineinfo"]
|
||||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||||
return ["-O3", "--use_fast_math"] + extra_cuda_flags
|
return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
|
@ -1,5 +1,5 @@
|
|||||||
from ..cuda_extension import _CudaExtension
|
from ...cuda_extension import _CudaExtension
|
||||||
from ..utils import append_nvcc_threads
|
from ...utils import append_nvcc_threads
|
||||||
|
|
||||||
|
|
||||||
class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
||||||
@ -7,15 +7,11 @@ class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
|||||||
super().__init__(name="scaled_masked_softmax_cuda")
|
super().__init__(name="scaled_masked_softmax_cuda")
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/scaled_masked_softmax_kernel.cu"]] + [
|
||||||
self.csrc_abs_path(fname)
|
self.pybind_abs_path("softmax/scaled_masked_softmax.cpp")
|
||||||
for fname in ["cuda/pybind/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_kernel.cu"]
|
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
|
||||||
return [self.get_cuda_home_include()]
|
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
return ["-O3"] + self.version_dependent_macros
|
return ["-O3"] + self.version_dependent_macros
|
||||||
|
|
||||||
@ -28,5 +24,5 @@ class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
|||||||
"-U__CUDA_NO_HALF2_OPERATORS__",
|
"-U__CUDA_NO_HALF2_OPERATORS__",
|
||||||
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
|
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
|
||||||
]
|
]
|
||||||
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
|
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags()
|
||||||
return append_nvcc_threads(ret)
|
return append_nvcc_threads(ret)
|
@ -1,22 +1,18 @@
|
|||||||
from ..cuda_extension import _CudaExtension
|
from ...cuda_extension import _CudaExtension
|
||||||
from ..utils import append_nvcc_threads, get_cuda_cc_flag
|
from ...utils import append_nvcc_threads, get_cuda_cc_flag
|
||||||
|
|
||||||
|
|
||||||
class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
|
class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(name="scaled_upper_triangle_masked_softmax_cuda")
|
super().__init__(name="scaled_upper_triangle_masked_softmax_cuda")
|
||||||
|
|
||||||
def include_dirs(self):
|
|
||||||
return [self.get_cuda_home_include()]
|
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/pybind/scaled_upper_triang_masked_softmax.cpp",
|
"kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu",
|
||||||
"cuda/scaled_upper_triang_masked_softmax_kernel.cu",
|
|
||||||
]
|
|
||||||
]
|
]
|
||||||
|
] + [self.pybind_abs_path("softmax/scaled_upper_triang_masked_softmax.cpp")]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
@ -30,5 +26,5 @@ class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
|
|||||||
"--expt-extended-lambda",
|
"--expt-extended-lambda",
|
||||||
]
|
]
|
||||||
extra_cuda_flags.extend(get_cuda_cc_flag())
|
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||||
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
|
ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
|
||||||
return append_nvcc_threads(ret)
|
return append_nvcc_threads(ret)
|
Loading…
Reference in New Issue
Block a user