[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
Frank Lee 2024-01-25 17:01:48 +08:00 committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

@@ -140,7 +140,7 @@ jobs:
       - name: Install Colossal-AI
         run: |
-          CUDA_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install -v -e .
           pip install -r requirements/requirements-test.txt
       - name: Store Colossal-AI Cache

@@ -55,7 +55,7 @@ jobs:
         if: steps.check-avai.outputs.avai == 'true'
         run: |
           [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
-          CUDA_EXT=1 pip install -v -e .
+          BUILD_EXT=1 pip install -v -e .
           cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
           pip install -r requirements/requirements-test.txt

@@ -1,4 +1,4 @@
 include *.txt README.md
 recursive-include requirements *.txt
 recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
-recursive-include op_builder *.py
+recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi

@@ -48,6 +48,12 @@ class BaseAccelerator(ABC):
     # =======================
     # device APIs
     # =======================
+    @abstractmethod
+    def get_version(self) -> str:
+        """
+        Return the version of the accelerator which torch is built against.
+        """
+
     @abstractmethod
     def get_current_device(self) -> torch.device:
         """
@@ -66,6 +72,7 @@ class BaseAccelerator(ABC):
         Bind the current process to a device.
         """

+    @abstractmethod
     def get_device_name(self, device: Union[torch.device, int]) -> str:
         """
         Return the name of the device.

@@ -24,6 +24,12 @@ class CpuAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_version(self) -> str:
+        """
+        Return the version of the accelerator which torch is built against.
+        """
+        return ""
+
     def get_current_device(self) -> torch.device:
         """
         Return the current device.

@@ -21,6 +21,12 @@ class CudaAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_version(self) -> str:
+        """
+        Return the version of the accelerator which torch is built against.
+        """
+        return torch.version.cuda
+
     def get_current_device(self) -> torch.device:
         """
         Return the current device.

@@ -27,6 +27,12 @@ class NpuAccelerator(BaseAccelerator):
     # =======================
     # device APIs
     # =======================
+    def get_version(self) -> str:
+        """
+        Return the version of the accelerator which torch is built against.
+        """
+        return torch.version.npu
+
     def get_current_device(self) -> torch.device:
         """
         Return the current device.
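Note: a minimal usage sketch of the new accelerator version API. The get_accelerator() helper and its import path are assumptions for illustration; they are not part of this diff.

# Hedged sketch: get_accelerator() is assumed to return the CpuAccelerator /
# CudaAccelerator / NpuAccelerator instance that matches the current host.
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
print(accelerator.get_version())         # torch.version.cuda on CUDA, torch.version.npu on NPU, "" on CPU
print(accelerator.get_current_device())  # e.g. device(type='cuda', index=0)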

@@ -1,14 +0,0 @@
from .cpu_adam_loader import CPUAdamLoader
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
from .extensions.flash_attention import AttnMaskType
from .flash_attention_loader import ColoAttention, FlashAttentionLoader
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
"CPUAdamLoader",
"FlashAttentionLoader",
"ColoAttention",
"AttnMaskType",
]

@@ -1,28 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, List
from .extensions.base_extension import BaseExtension
class BaseKernelLoader(ABC):
"""
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]):
self._extension_map = extension_map
self._supported_device = supported_device
def run_checks(self):
# run supported device check and other possible checks
pass
@abstractmethod
def fetch_kernel(self):
pass
def load(self):
self.run_checks()
return self.fetch_kernel()

@@ -1,64 +0,0 @@
import platform
from collections import OrderedDict
from .base_kernel_loader import BaseKernelLoader
from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension
class CPUAdamLoader(BaseKernelLoader):
"""
CPU Adam Loader
Usage:
# init
cpu_adam = CPUAdamLoader().load()
cpu_adam_op = cpu_adam.CPUAdamOptimizer(
alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
)
...
# optim step
cpu_adam_op.step(
step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
params, grads, exp_avg, exp_avg_sq, loss_scale,
)
Args:
func CPUAdamOptimizer:
alpha (float): learning rate. Default to 1e-3.
beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
weight_decay (float): weight decay (L2 penalty). Default to 0.
adamw_mode (bool): whether to use the adamw. Default to True.
func step:
step (int): current step.
lr (float): learning rate.
beta1 (float): coefficients used for computing running averages of gradient.
beta2 (float): coefficients used for computing running averages of its square.
epsilon (float): term added to the denominator to improve numerical stability.
weight_decay (float): weight decay (L2 penalty).
bias_correction (bool): whether to use bias correction.
params (torch.Tensor): parameter.
grads (torch.Tensor): gradient.
exp_avg (torch.Tensor): exp average.
exp_avg_sq (torch.Tensor): exp average square.
loss_scale (float): loss scale value.
"""
def __init__(self):
super().__init__(
extension_map=OrderedDict(
arm=ArmCPUAdamExtension,
x86=X86CPUAdamExtension,
),
supported_device=["cpu"],
)
def fetch_kernel(self):
if platform.machine() == "x86_64":
kernel = self._extension_map["x86"]().fetch()
elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
kernel = self._extension_map["arm"]().fetch()
else:
raise Exception("not supported")
return kernel
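Note: a hedged, illustrative expansion of the usage block in the docstring above; hyper-parameters and tensor sizes are arbitrary, and loss_scale=-1 is assumed to mean "no loss scaling".

import torch

cpu_adam = CPUAdamLoader().load()  # picks the x86 or ARM extension based on platform.machine()
cpu_adam_op = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.99, 1e-8, 0.0, True)

params = torch.rand(64)
grads = torch.rand(64)
exp_avg = torch.zeros(64)
exp_avg_sq = torch.zeros(64)
cpu_adam_op.step(1, 1e-3, 0.9, 0.99, 1e-8, 0.0, True,
                 params, grads, exp_avg, exp_avg_sq, -1)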

@@ -1,63 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
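Note: the invariant in the comment above (seq_x @ seq_w == x @ w) is easy to check in PyTorch; a hedged sketch, where x_map is a column permutation of x and seq_w is w with its rows reordered by the same permutation:

import torch

x = torch.randn(4, 8)
w = torch.randn(8, 3)
x_map = torch.randperm(8)

seq_x = x[:, x_map]   # what column_remap_cuda writes into x_new
seq_w = w[x_map, :]   # weight rows in sequential group order
assert torch.allclose(seq_x @ seq_w, x @ w, atol=1e-5)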

@@ -1,19 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif

@@ -1,58 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif

@@ -1,75 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}

@@ -1,55 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif

@@ -1,49 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// Workaround for a bug in hipamd, backported from upstream.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Workaround for hipify_python using rocblas instead of hipblas.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
int m,
int n,
int k,
const half* alpha,
const half* AP,
int lda,
const half* BP,
int ldb,
const half* beta,
half* CP,
int ldc) {
return hipblasHgemm(handle, transA, transB, m, n, k,
reinterpret_cast<const hipblasHalf *>(alpha),
reinterpret_cast<const hipblasHalf *>(AP), lda,
reinterpret_cast<const hipblasHalf *>(BP), ldb,
reinterpret_cast<const hipblasHalf *>(beta),
reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif

@@ -1,254 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "q4_matrix.cuh"
#include "q4_matmul.cuh"
#include "column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
// buffer size used for sanity checks
temp_state.numel(),
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}

@@ -1,294 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif

@@ -1,260 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "util.cuh"
#include "matrix.cuh"
#include "cu_compat.cuh"
#include "cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
const float alpha = 1.0f;
const float beta = no_zero ? 1.0f : 0.0f;
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}

@@ -1,43 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "tuning.h"
// Workaround for hipify_python using rocblas instead of hipblas.
#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
#define rocblas_handle hipblasHandle_t
#endif
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
);
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
);
#endif

@@ -1,225 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matrix.cuh"
#include <vector>
#include "util.cuh"
#include "matrix.cuh"
using namespace std;
const int UNSHUF_BLOCKSIZE_X = 64;
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
vector<Q4Matrix*> g_q4_matrices;
void g_q4_keep_matrix(Q4Matrix* m)
{
g_q4_matrices.push_back(m);
}
void g_q4_free_matrices()
{
for (const auto& m : g_q4_matrices) delete m;
g_q4_matrices.clear();
}
Q4Matrix::Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
) :
height(_height),
width(_width),
groups(_groups),
device(_device)
{
cudaSetDevice(device);
cuda_qweight = _qweight;
cuda_qzeros = _qzeros;
cuda_scales = _scales;
groupsize = height / groups;
if (_g_idx) make_sequential(_g_idx);
}
Q4Matrix::~Q4Matrix()
{
}
// Make sequential
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint32_t* __restrict__ x_map,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int x_map_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = x_map[x_map_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Move to CUDA
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
height / 8,
1
);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ w,
half* __restrict__ out, // (y)
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int width,
const int groupsize
)
{
// Start of block
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, height / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
// Groupsize version
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void Q4Matrix::reconstruct(half* out)
{
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height / 8 + threads.y - 1) / threads.y,
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}
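Note: reconstruct_kernel above unpacks eight 4-bit weights from every uint32 word and applies the per-group scale and (offset-by-one) zero point; a hedged Python rendering of that per-word arithmetic, for reference only:

def dequant_word(w_read: int, w_zero: int, w_scale: float) -> list:
    # w_zero already carries the +1 offset, mirroring w_zeros_.item(...) + 1 in the kernel
    return [(((w_read >> s) & 0x0F) - w_zero) * w_scale for s in range(0, 32, 4)]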

@@ -1,53 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
class Q4Matrix
{
public:
int device;
int height;
int width;
int groups;
int groupsize;
uint32_t* cuda_qweight = NULL;
uint32_t* cuda_qzeros = NULL;
half* cuda_scales = NULL;
uint32_t* cuda_x_map = NULL;
Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
);
~Q4Matrix();
void reconstruct(half* out);
private:
void make_sequential(const uint32_t* cpu_g_idx);
};
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif

@@ -1,12 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning {
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif

@@ -1,33 +0,0 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif

@@ -1,191 +0,0 @@
#include "block_reduce.h"
#include "cuda_util.h"
#include "kernels.h"
#include "ls_cub.cuh"
ls::cub::CachingDeviceAllocator g_allocator(true);
template <typename T>
__global__ void ls_cross_entropy_fw_kernel(
const T *__restrict__ inputs, const int *__restrict__ targets,
float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
const int padding_idx, const float epsilon, const int vocab_size) {
/* step1: compute each thread's max_logit and sum_exp_logit, store in
* max_input, sum_exp_logit */
const int block_start = blockIdx.x * vocab_size;
const int left_idx = block_start + threadIdx.x;
const int right_idx = (blockIdx.x + 1) * vocab_size;
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
int target_tid = targets[blockIdx.x];
if (target_tid == padding_idx) {
if (threadIdx.x == 0) {
nll_loss_outputs[blockIdx.x] = 0.f;
outputs[blockIdx.x] = 0.f;
}
return;
}
for (int i = left_idx; i < right_idx; i += blockDim.x) {
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
}
blockReduce<ReduceType::kMax, 1>(max_input);
__shared__ float s_max_input;
if (threadIdx.x == 0) {
s_max_input = max_input[0];
}
__syncthreads();
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float logit = static_cast<float>(inputs[i]) - s_max_input;
sum_logits[0] += logit;
sum_logits[1] += expf(logit);
}
blockReduce<ReduceType::kSum, 2>(sum_logits);
__shared__ float s_sum_logit;
__shared__ float s_sum_exp;
if (threadIdx.x == 0) {
s_sum_logit = sum_logits[0];
s_sum_exp = sum_logits[1];
}
__syncthreads();
float eps_i = epsilon / (vocab_size - 1);
if (threadIdx.x == 0) {
// neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
float nll_loss = logf(s_sum_exp) -
static_cast<float>(inputs[block_start + target_tid]) +
s_max_input;
nll_loss_outputs[blockIdx.x] = nll_loss;
float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
outputs[blockIdx.x] =
(1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
}
}
template <typename T>
__global__ void ls_cross_entropy_bw_kernel(
const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
const int *__restrict__ targets, T *__restrict__ grad_inputs,
const int padding_idx, const float epsilon, const int vocab_size) {
/* step1: compute each thread's max_logit and sum_exp_logit, store in
* max_input, sum_exp_logit */
const int block_start = blockIdx.x * vocab_size;
const int left_idx = block_start + threadIdx.x;
const int right_idx = (blockIdx.x + 1) * vocab_size;
float max_input[1] = {REDUCE_FLOAT_INF_NEG};
float sum_logits[1] = {0.f};
const float grad_out = static_cast<float>(grad_outputs[0]);
int target_tid = targets[blockIdx.x];
if (target_tid == padding_idx) {
for (int i = left_idx; i < right_idx; i += blockDim.x) {
grad_inputs[i] = 0.f;
}
return;
}
for (int i = left_idx; i < right_idx; i += blockDim.x) {
max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
}
blockReduce<ReduceType::kMax, 1>(max_input);
__shared__ float s_max_input;
if (threadIdx.x == 0) {
s_max_input = max_input[0];
}
__syncthreads();
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float logit = static_cast<float>(inputs[i]) - s_max_input;
sum_logits[0] += expf(logit);
}
blockReduce<ReduceType::kSum, 1>(sum_logits);
__shared__ float s_sum_exp;
if (threadIdx.x == 0) {
s_sum_exp = sum_logits[0];
}
__syncthreads();
float eps_i = epsilon / (vocab_size - 1);
float nll_weight = 1.0 - epsilon - eps_i;
for (int i = left_idx; i < right_idx; i += blockDim.x) {
float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
float grad = 0;
grad += (vocab_size * prob - 1) * eps_i;
grad += prob * nll_weight;
if ((i - block_start) == target_tid) {
grad -= nll_weight;
}
grad_inputs[i] = grad_out * grad;
}
}
template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
float *outputs_ptr, float *nll_loss_ptr,
float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size,
const int seq_len, const int vocab_size,
cudaStream_t stream) {
int grid_dim = batch_size * seq_len;
float *nll_loss_buffer = loss_buffer + grid_dim;
ls_cross_entropy_fw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
epsilon, vocab_size);
int num_items = grid_dim;
void *d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
loss_buffer, outputs_ptr,
num_items, stream));
CHECK_GPU_ERROR(
g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
loss_buffer, outputs_ptr,
num_items, stream));
CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
nll_loss_buffer, nll_loss_ptr,
num_items, stream));
CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
}
template void launch_cross_entropy_fw<float>(
const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template void launch_cross_entropy_fw<__half>(
const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr,
const int padding_idx, const float epsilon,
const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream) {
int grid_dim = batch_size * seq_len;
ls_cross_entropy_bw_kernel<<<grid_dim, MAX_THREADS, 0, stream>>>(
grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
epsilon, vocab_size);
}
template void launch_cross_entropy_bw<float>(
const float *grad_outputs_ptr, const float *inputs_ptr,
const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template void launch_cross_entropy_bw<__half>(
const float *grad_outputs_ptr, const __half *inputs_ptr,
const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
const float epsilon, const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
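Note: the forward kernel above computes label-smoothed cross entropy per token; a hedged PyTorch reference of the same math (epsilon is the smoothing factor, V the vocabulary size), useful as a correctness oracle for the fused kernel:

import torch

def smoothed_cross_entropy(logits: torch.Tensor, target: int, epsilon: float) -> torch.Tensor:
    V = logits.numel()
    eps_i = epsilon / (V - 1)
    x = logits - logits.max()                 # the s_max_input subtraction, for stability
    log_sum_exp = torch.log(torch.exp(x).sum())
    nll = log_sum_exp - x[target]             # matches nll_loss_outputs[blockIdx.x]
    sum_nll = V * log_sum_exp - x.sum()       # matches sum_nll_loss
    return (1.0 - epsilon - eps_i) * nll + eps_i * sum_nll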

@@ -1,88 +0,0 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#include "cublas_wrappers.h"
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const float *A,
const float *B, float *C, cublasGemmAlgo_t algo) {
cublasStatus_t status =
cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
(const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
(const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
(const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const __half *A,
const __half *B, __half *C, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmEx(
handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F,
(transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C,
CUDA_R_16F, m, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const float *A, const float *B, float *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmStridedBatchedEx(
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C,
batch, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
"error: %d) \n",
batch, m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const __half *A, const __half *B, __half *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch, cublasGemmAlgo_t algo) {
cublasStatus_t status = cublasGemmStridedBatchedEx(
handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
(op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
(op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C,
batch, CUDA_R_32F, algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m, n, k, (int)status);
return EXIT_FAILURE;
}
return 0;
}

@@ -1,169 +0,0 @@
#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/transform_reduce.h>
#include "cuda_util.h"
/* GPU function guard */
std::string _cudaGetErrorString(cudaError_t error) {
return cudaGetErrorString(error);
}
std::string _cudaGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "CUBLAS_UNKNOW";
}
template <typename T>
void check_gpu_error(T result, char const *const func, const char *const file,
int const line) {
if (result) {
throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" +
std::to_string(line) +
"): " + (_cudaGetErrorString(result)) + "\n");
}
}
template void check_gpu_error<cudaError_t>(cudaError_t result,
char const *const func,
const char *const file,
int const line);
template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
char const *const func,
const char *const file,
int const line);
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele) {
std::cout << outn << ": ";
std::vector<T> hout(num_output_ele, (T)0);
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
cudaMemcpyDeviceToHost);
for (int i = 0; i < num_output_ele; i++) {
std::cout << hout[i] << ", ";
}
std::cout << std::endl;
}
template <>
void print_vec<__half>(const __half *outv, std::string outn,
int num_output_ele) {
std::cout << outn << ": ";
std::vector<__half> hout(num_output_ele, (__half)0.f);
cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
cudaMemcpyDeviceToHost);
for (int i = 0; i < num_output_ele; i++) {
std::cout << __half2float(hout[i]) << ", ";
}
std::cout << std::endl;
}
template void print_vec<float>(const float *outv, std::string outn,
int num_output_ele);
template void print_vec<int>(const int *outv, std::string outn,
int num_output_ele);
template void print_vec<__half>(const __half *outv, std::string outn,
int num_output_ele);
template <typename T>
T *cuda_malloc(size_t ele_num) {
size_t byte_size = ele_num * sizeof(T);
T *pdata = nullptr;
CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
return pdata;
}
template float *cuda_malloc<float>(size_t ele_num);
template __half *cuda_malloc<__half>(size_t ele_num);
template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);
void cuda_free(void *pdata) {
if (pdata != nullptr) {
cudaFree(pdata);
}
}
template <typename T>
struct _isnan {
__device__ bool operator()(T a) const { return isnan(a); }
};
template <>
struct _isnan<__half> {
__device__ bool operator()(const __half a) const { return __hisnan(a); }
};
template <typename T>
struct _isinf {
__device__ bool operator()(T a) const { return isinf(a); }
};
template <>
struct _isinf<__half> {
__device__ bool operator()(const __half a) const { return __hisinf(a); }
};
template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
std::string file, int line, cudaStream_t stream) {
// check_nan_inf = true for checking nan
// check_nan_inf = false for checking inf
bool res = false;
std::string msg = file + "(" + std::to_string(line) + "): ";
if (check_nan_inf) {
msg += "nan.";
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
data_ptr + dsize, _isnan<T>(), false,
thrust::logical_or<bool>());
} else {
msg += "inf.";
res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
data_ptr + dsize, _isinf<T>(), false,
thrust::logical_or<bool>());
}
if (res) {
throw std::runtime_error(msg);
}
std::cout << msg << " [check pass]." << std::endl;
}
template void check_nan_inf<float>(const float *data_ptr, int dsize,
bool check_nan_inf, std::string file,
int line, cudaStream_t stream);
template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
bool check_nan_inf, std::string file,
int line, cudaStream_t stream);

File diff suppressed because it is too large Load Diff

View File

@ -1,232 +0,0 @@
#include <cooperative_groups.h>
#include "kernels.h"
namespace cg = cooperative_groups;
/**
@brief: fuse_transpose_bias
Calculate the sum of elements in each column of the matrix.
@thread
gridDim.x = ceil(cols / WARP_SIZE)
blockDim.x = WARP_SIZE
blockDim.y = WARP_SIZE
@param
inp: [rows, cols]
out: [cols]
rows: the number of rows in the matrix
cols: the number of cols in the matrix
*/
template <typename T>
__global__ void column_sum_reduce(const T *__restrict__ inp,
T *__restrict__ out, int rows, int cols) {
__shared__ float tile[WARP_SIZE][WARP_SIZE];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
int y_stride = cols * WARP_SIZE;
float localSum = 0;
// Loop across matrix row
// TODO: optimize to log complexity
if (idx < cols) {
int offset = flat_2dim(threadIdx.y, idx, cols);
for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
// The sum of a row in tile equals the sum of a column in the original matrix
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
// The change of threadIdx.x is continuous
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();
// Calculate the sum of a row in tile
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
if (pos < cols) out[pos] = sum;
}
}
// [r, c] -> [c]
template <>
void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
int rows, int cols,
cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<float>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
int rows, int cols,
cudaStream_t stream) {
dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
dim3 block_dim(WARP_SIZE, WARP_SIZE);
column_sum_reduce<__half>
<<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
/**
@brief: fused_add2
Add two matrix inp1 and inp2 to out.
@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
int hidden_dim);
template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
const float *inp2, int hidden_dim) {
int row_id = blockIdx.x;
int offset = flat_2dim(row_id, 0, hidden_dim);
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
float4 *out_4 = reinterpret_cast<float4 *>(out);
float4 vinp1;
float4 vinp2;
float4 val;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinp1 = inp1_4[offset + i];
vinp2 = inp2_4[offset + i];
val.x = vinp1.x + vinp2.x;
val.y = vinp1.y + vinp2.y;
val.z = vinp1.z + vinp2.z;
val.w = vinp1.w + vinp2.w;
out_4[offset + i] = val;
}
}
template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
const __half *inp2, int hidden_dim) {
int row_id = blockIdx.x;
int offset = flat_2dim(row_id, 0, hidden_dim);
const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
float4 *out_4 = reinterpret_cast<float4 *>(out);
float4 vinp1;
float4 vinp2;
float4 val;
__half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
__half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
__half2 *h2_val = reinterpret_cast<__half2 *>(&val);
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinp1 = inp1_4[offset + i];
vinp2 = inp2_4[offset + i];
h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
out_4[offset + i] = val;
}
}
//[b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
int batch_size, int seq_len, int hidden_dim,
cudaStream_t &stream) {
hidden_dim >>= 2;
dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim);
}
template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
const __half *inp2, int batch_size, int seq_len,
int hidden_dim, cudaStream_t &stream) {
hidden_dim >>= 3;
dim3 grid_dim(batch_size * seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
hidden_dim);
}
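// Note (sketch, an assumption rather than original code): both launchers above
// vectorize with float4 loads, i.e. 4 floats or 8 halves per element group, so
// hidden_dim must be divisible by that width or trailing elements are dropped.
static inline bool fused_add2_dim_ok(int hidden_dim, bool is_half) {
  return hidden_dim % (is_half ? 8 : 4) == 0;
}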
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
int sz0, int sz2, int sz1_1, int sz1_2) {
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
if (idx >= nele) {
return;
}
float4 *dst_ptr = (float4 *)output + idx;
int idx2 = idx % sz2;
idx = idx / sz2;
int idx1 = idx % (sz1_1 + sz1_2);
int idx0 = idx / (sz1_1 + sz1_2);
float4 *src_ptr = nullptr;
int sz1 = 0;
if (idx1 < sz1_1) {
sz1 = sz1_1;
src_ptr = (float4 *)inp1;
} else {
idx1 -= sz1_1;
sz1 = sz1_2;
src_ptr = (float4 *)inp2;
}
src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
dst_ptr[0] = src_ptr[0];
}
template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
float *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) {
sz2 >>= 2;
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}
template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
__half *output, int sz0, int sz2, int sz1_1,
int sz1_2, cudaStream_t stream) {
sz2 >>= 3;
int nele = sz0 * sz2 * (sz1_1 + sz1_2);
int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}

View File

@ -1,36 +0,0 @@
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <iostream>
#include <string>
#include "cuda_util.h"
class Context {
public:
Context() : _stream(nullptr) {
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
}
virtual ~Context() {}
static Context &Instance() {
static Context _ctx;
return _ctx;
}
void set_stream(cudaStream_t stream) {
_stream = stream;
CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream));
}
cudaStream_t get_stream() { return _stream; }
cublasHandle_t get_cublashandle() { return _cublasHandle; }
private:
cudaStream_t _stream;
cublasHandle_t _cublasHandle;
};
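// Hedged usage sketch (hypothetical caller): bind the current CUDA stream to
// the shared cuBLAS handle before issuing GEMMs, as the multi-head attention
// bindings later in this diff do.
inline void bind_stream_to_context(cudaStream_t stream) {
  Context::Instance().set_stream(stream);
  cublasHandle_t handle = Context::Instance().get_cublashandle();
  (void)handle;  // cuBLAS calls on this handle now run on `stream`
}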

View File

@ -1,46 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <type_traits>
#include "cuda_util.h"
template <typename T>
class CrossEntropyLayer {
public:
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
virtual ~CrossEntropyLayer();
void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
float *nll_loss_ptr);
void Backward(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr);
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
private:
void allocate_mem_buffer() {
// allocate local gpu memory
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
}
void free_mem_buffer() {
// free local gpu memory
cuda_free(_loss_buffer);
}
const int _padding_idx;
const float _epsilon;
const int _max_batch_tokens;
size_t _batch_size;
size_t _seq_len;
size_t _vocab_size;
float *_loss_buffer;
};
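// Hedged usage sketch (hypothetical device pointers; the epsilon and
// padding_idx values below are assumptions). The batch shape must be set
// before each forward call.
inline void cross_entropy_fw(CrossEntropyLayer<float> &ce,
                             const float *inputs_ptr, const int *targets_ptr,
                             float *outputs_ptr, float *nll_loss_ptr,
                             int batch_size, int seq_len, int vocab_size) {
  ce.set_cur_batch_shape(batch_size, seq_len, vocab_size);
  ce.Forward(inputs_ptr, targets_ptr, outputs_ptr, nll_loss_ptr);
}
// e.g. CrossEntropyLayer<float> ce(/*epsilon=*/0.1f, /*padding_idx=*/0,
//                                  /*max_batch_tokens=*/4096);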

View File

@ -1,41 +0,0 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const float *A,
const float *B, float *C,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float *alpha, const float *beta, const __half *A,
const __half *B, __half *C,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
const float *alpha, const float *beta,
const float *A, const float *B, float *C,
cublasOperation_t op_A, cublasOperation_t op_B,
int stride_A, int stride_B, int stride_C,
int batch,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
int cublas_strided_batched_gemm(
cublasHandle_t handle, int m, int n, int k, const float *alpha,
const float *beta, const __half *A, const __half *B, __half *C,
cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B,
int stride_C, int batch,
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);

View File

@ -1,34 +0,0 @@
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <math_constants.h>
#include <chrono>
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
template <typename T>
void check_gpu_error(T result, char const *const func, const char *const file,
int const line);
#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__)
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele);
template <typename T>
T *cuda_malloc(size_t ele_num);
void cuda_free(void *pdata);
template <typename T>
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
std::string file, int line, cudaStream_t stream);
#define CHECK_NAN_INF(ptr, size, stream) \
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
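// Hedged usage sketch (not original code; hypothetical buffer): CHECK_GPU_ERROR
// throws on any non-success CUDA/cuBLAS status, and CHECK_NAN_INF runs both the
// nan and the inf scan, throwing if either is found.
#include <cuda_runtime.h>
inline void sanity_check(float *d_buf, int n, cudaStream_t stream) {
  CHECK_GPU_ERROR(cudaMemsetAsync(d_buf, 0, n * sizeof(float), stream));
  CHECK_NAN_INF(d_buf, n, stream);
}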

View File

@ -1,96 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>
#include "kernels.h"
template <typename T>
class Dropout {
public:
struct Config {
float ratio;
bool training;
Config(float r) : ratio(r), training(true) {}
float RATIO() const { return training ? ratio : 0.0; }
};
Dropout(const Config &config, size_t max_ele_num)
: _config(config), _mask(nullptr) {
_mask = cuda_malloc<uint8_t>(max_ele_num);
}
virtual ~Dropout() { cuda_free(_mask); }
// after attention softmax
void dropout(T *output, const T *input, int count, cudaStream_t stream,
bool bwd = false) {
launch_ls_dropout<T>(output, input, _mask, count, _config.RATIO(), stream,
bwd);
}
void d_dropout(T *d_inp_out, int count, cudaStream_t stream) {
launch_ls_dropout<T>(d_inp_out, d_inp_out, _mask, count, _config.RATIO(),
stream, true);
}
// transformer layer's postprocessing dropout, after attn or ffn module,
// before residual add.
void bias_dropout_residual(T *output, const T *input, const T *residual,
const T *bias, int rows, int cols,
cudaStream_t stream) {
launch_ls_dropout_res_bias<T>(output, input, _mask, bias, residual,
rows * cols, cols, _config.RATIO(), stream);
}
void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output,
int rows, int cols, cudaStream_t stream) {
launch_ls_dropout_bias_bwd<T>(d_input, d_bias, d_output, _mask, rows, cols,
_config.RATIO(), stream);
}
// dropout inside ffn.
void bias_act_dropout(T *output, const T *input, const T *bias, int rows,
int cols, std::string activation_fn,
cudaStream_t stream) {
if (activation_fn == "relu") {
launch_ls_dropout_act_bias<ActivationType::kRelu, T>(
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
stream);
} else if (activation_fn == "gelu") {
launch_ls_dropout_act_bias<ActivationType::kGelu, T>(
output, input, _mask, bias, rows * cols, cols, _config.RATIO(),
stream);
} else {
throw std::runtime_error("not supported activation: " + activation_fn);
}
}
void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input,
const T *bias, int rows, int cols,
std::string activation_fn, cudaStream_t stream) {
if (activation_fn == "relu") {
launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, T>(
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
_config.RATIO(), stream);
} else if (activation_fn == "gelu") {
launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, T>(
d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols,
_config.RATIO(), stream);
} else {
throw std::runtime_error("not supported activation: " + activation_fn);
}
}
bool HasDropout() const { return _config.RATIO() > 0.0; }
void SetTrainingMode(bool training) { _config.training = training; }
private:
uint8_t *_mask;
Config _config;
};
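// Hedged usage sketch (hypothetical sizes and pointers): drive one
// attention-probability dropout forward. A real layer keeps the Dropout module
// alive across steps, as MultiHeadAttention later in this diff does.
inline void attn_prob_dropout_fw(float *out, const float *in, int count,
                                 size_t max_ele_num, bool training,
                                 cudaStream_t stream) {
  Dropout<float> drop(Dropout<float>::Config(0.1f), max_ele_num);
  drop.SetTrainingMode(training);  // RATIO() becomes 0 when not training
  drop.dropout(out, in, count, stream);
}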

View File

@ -1,69 +0,0 @@
#pragma once
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>
#include "cublas_wrappers.h"
#include "kernels.h"
template <typename T>
class FeedForward {
public:
struct Config {
int outputSize;
int inputSize;
std::array<int, 3> gemm_algos;
Config(int outputs, int inputs)
: outputSize(outputs),
inputSize(inputs),
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
};
FeedForward(Config config) : config_(config) {}
~FeedForward() {}
void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
cublasHandle_t &_cublasHandle) {
float alpha = T(1.);
float beta = T(0.);
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
out, cublasGemmAlgo_t(config_.gemm_algos[0]));
}
void Backward(int bsz, const T *out_grad, const T *input_ptr,
const T *weights, T *weights_grad, T *bias_grad,
cublasHandle_t &_cublasHandle, cudaStream_t &stream,
T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
bool compute_bias = true) {
float alpha = (T)1.0, beta = (T)0.0;
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));
cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
if (compute_bias) {
launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz,
config_.outputSize, stream);
}
}
void reset_size(int outputSize, int inputSize) {
config_.outputSize = outputSize;
config_.inputSize = inputSize;
}
private:
Config config_;
};
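// Hedged usage sketch (hypothetical device buffers): Forward is a bias-free
// linear layer; with row-major [bsz, inputSize] input and
// [outputSize, inputSize] weights it writes a row-major [bsz, outputSize]
// output via a single GEMM.
inline void linear_fw(int bsz, const float *d_in, const float *d_w,
                      float *d_out, cublasHandle_t &handle) {
  FeedForward<float> ff(FeedForward<float>::Config(/*outputs=*/1024,
                                                   /*inputs=*/4096));
  ff.Forward(bsz, d_in, d_w, d_out, handle);
}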

View File

@ -1,275 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdexcept>
#define MAX_THREADS 1024
#define WARP_SIZE 32
enum class ActivationType { kRelu, kGelu };
void launch_curand_init(int total_count, int dim, cudaStream_t stream);
template <typename T>
void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
const T *scale, const T *bias, int batch_size,
int hidden_dim, cudaStream_t stream);
template <typename T>
void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
const T *residual_grad, const T *inp_or_out, const T *gamma,
const T *betta, const T *vars, const T *means, int batch,
int hidden_dim, cudaStream_t stream[2]);
template <typename T>
void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads,
int from_len, int to_len, bool mask_future,
cudaStream_t stream);
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
int softmax_len, cudaStream_t stream);
// [b, s, h] -> [b, nh, s, ad]
template <typename T>
void launch_transform_0213(T *output, const T *vals, int batch_size,
int seq_length, int hidden_dim, int nhead,
cudaStream_t stream);
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <typename T>
void launch_bias_add_transform_20314(T *output, const T *input, const T *bias,
int dim_0, int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream);
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <typename T>
void launch_transform4d_0213(T *output, const T *vals, int batch_size,
int seq_len, int hidden_dim, int nhead,
int trans_count, cudaStream_t stream);
template <typename T>
void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count,
float ratio, cudaStream_t stream, bool backward = false);
template <typename T>
void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask,
const T *bias, const T *residual,
int total_count, int dim, float ratio,
cudaStream_t stream);
template <ActivationType, typename T>
void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask,
const T *bias, int total_count, int dim,
float ratio, cudaStream_t stream);
template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
const T *bias, const T *out_grad,
const uint8_t *mask, int row_size, int dim,
float ratio, cudaStream_t stream);
template <typename T>
void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols,
cudaStream_t stream);
void launch_param_update(const float *input, __half *output, int size,
cudaStream_t stream);
template <typename T>
void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0,
int sz2, int sz1_1, int sz1_2, cudaStream_t stream);
template <typename T>
void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size,
int seq_len, int hidden_size, cudaStream_t &stream);
template <typename T>
void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
float *outputs_ptr, float *nll_loss_ptr,
float *loss_buffer, const int padding_idx,
const float epsilon, const int batch_size,
const int seq_len, const int vocab_size,
cudaStream_t stream);
template <typename T>
void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
const int *targets_ptr, T *grad_inputs_ptr,
const int padding_idx, const float epsilon,
const int batch_size, const int seq_len,
const int vocab_size, cudaStream_t stream);
template <typename T>
void launch_lookup_scale_pos_dropout(
T *output, const int *input, const T *embeddings, const T *pos_embeddings,
uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
int padding_idx, float dropout_ratio, int step, cudaStream_t &stream);
template <typename T>
void launch_d_lookup_scale_pos_dropout(
T *grad_embeddings, const T *grad_output, const int *input,
const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim,
int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream);
/* Convert 2-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) {
return id1 * dim2 + id2;
}
/* Convert 3-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
int dim2, int dim3) {
return id1 * dim2 * dim3 + id2 * dim3 + id3;
}
/* Convert 4-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
int id4, int dim2, int dim3,
int dim4) {
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
int res = id4;
int ld = dim4;
res += id3 * ld;
ld *= dim3;
res += id2 * ld;
ld *= dim2;
res += id1 * ld;
return res;
}
/* Convert 5-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3,
int id4, int id5, int dim2,
int dim3, int dim4,
int dim5) {
// return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) +
// id4*dim5 + id5;
int res = id5;
int ld = dim5;
res += id4 * ld;
ld *= dim4;
res += id3 * ld;
ld *= dim3;
res += id2 * ld;
ld *= dim2;
res += id1 * ld;
return res;
}
/* Convert 6-dim tensor index into vector index */
__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
int id4, int id5, int id6,
int dim2, int dim3, int dim4,
int dim5, int dim6) {
// return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) +
// id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6;
int res = id6;
int ld = dim6;
res += id5 * ld;
ld *= dim5;
res += id4 * ld;
ld *= dim4;
res += id3 * ld;
ld *= dim3;
res += id2 * ld;
ld *= dim2;
res += id1 * ld;
return res;
}
/* Convert vector index to 6-dim tensor index */
__forceinline__ __host__ __device__ void decompose_6dim(
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
int *id1, int *id2, int *id3, int *id4, int *id5) {
*id5 = src % dim5;
src /= dim5;
*id4 = src % dim4;
src /= dim4;
*id3 = src % dim3;
src /= dim3;
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 5-dim tensor index */
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
int dim2, int dim3,
int dim4, int *id0,
int *id1, int *id2,
int *id3, int *id4) {
*id4 = src % dim4;
src /= dim4;
*id3 = src % dim3;
src /= dim3;
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 4-dim tensor index */
__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
int dim2, int dim3,
int *id0, int *id1,
int *id2, int *id3) {
*id3 = src % dim3;
src /= dim3;
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 3-dim tensor index */
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
int dim2, int *id0,
int *id1, int *id2) {
*id2 = src % dim2;
src /= dim2;
*id1 = src % dim1;
*id0 = src / dim1;
}
/* Convert vector index to 2-dim tensor index */
__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1,
int *id0, int *id1) {
*id1 = src % dim1;
*id0 = src / dim1;
}
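// Hedged host-side check (not original code; needs nvcc so the __host__
// __device__ qualifiers resolve): flat_3dim and decompose_3dim are inverses.
#include <cassert>
inline void index_roundtrip_check() {
  int idx = flat_3dim(2, 1, 3, /*dim2=*/3, /*dim3=*/5);  // 2*15 + 1*5 + 3 = 38
  int i0, i1, i2;
  decompose_3dim(idx, /*dim1=*/3, /*dim2=*/5, &i0, &i1, &i2);
  assert(i0 == 2 && i1 == 1 && i2 == 3);
}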

View File

@ -1,12 +0,0 @@
// copied from https://github.com/dmlc/dgl/pull/2758
#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
#define DGL_ARRAY_CUDA_DGL_CUB_CUH_
#define CUB_NS_PREFIX namespace ls {
#define CUB_NS_POSTFIX }
#include "cub/cub.cuh"
#include "cub/util_allocator.cuh"
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#endif

View File

@ -1,65 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "kernels.h"
using namespace std;
template <typename T>
class Normalize_Layer {
public:
struct Config {
uint32_t hidden_dim;
bool use_mean;
Config(uint32_t hidden_dim, bool use_mean = false)
: hidden_dim(hidden_dim), use_mean(use_mean) {}
};
Normalize_Layer(Config config, size_t max_rows)
: config_(config), vars_(nullptr), means_(nullptr) {
vars_ = cuda_malloc<T>(max_rows);
if (config_.use_mean) {
means_ = cuda_malloc<T>(max_rows);
}
}
~Normalize_Layer() {
cuda_free(vars_);
cuda_free(means_);
}
void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta,
int batch_size, cudaStream_t stream) {
launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size,
config_.hidden_dim, stream);
}
/*
residual_grad, inp_or_out, betta should be treated carefully.
inp_or_out = input if use_mean else output
residual_grad, betta can be nullptr.
residual_grad will be added to dinp if it is not nullptr
which is useful in transformer layer when pre-ln
betta is only used to compute xhat;
(use_mean == false) ^ (betta == nullptr) should be true
*/
void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad,
const T *residual_grad, const T *inp_or_out, const T *gamma,
const T *betta, int batch_size, cudaStream_t stream[2]) {
launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad,
inp_or_out, gamma, betta, vars_, means_, batch_size,
config_.hidden_dim, stream);
}
inline bool use_mean() const { return config_.use_mean; }
private:
Config config_;
T *vars_;
T *means_;
};
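// Hedged usage sketch (hypothetical pointers): with use_mean=false only the
// variances are cached, which matches the configuration MultiHeadAttention
// uses later in this diff. A real layer keeps the module alive across calls.
inline void layernorm_fw(float *ln_res, const float *inp, const float *gamma,
                         const float *betta, int batch_tokens, int hidden_dim,
                         size_t max_rows, cudaStream_t stream) {
  Normalize_Layer<float> ln(
      Normalize_Layer<float>::Config(hidden_dim, /*use_mean=*/false), max_rows);
  ln.Forward(ln_res, inp, gamma, betta, batch_tokens, stream);
}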

View File

@ -1,42 +0,0 @@
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "kernels.h"
using namespace std;
template <typename T>
class Softmax {
public:
struct Config {
size_t nhead;
Config(size_t nhead) : nhead(nhead) {}
};
Softmax(Config config) : config_(config) {}
~Softmax() {}
void Forward(T *vals, const T *attn_mask, int batch_size, int from_len,
int to_len, cudaStream_t &stream, bool mask_future = true) {
launch_attn_softmax<T>(vals, attn_mask, batch_size, config_.nhead, from_len,
to_len, mask_future, stream);
}
void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len,
int to_len, cudaStream_t stream) {
launch_attn_softmax_bw<T>(out_grad, soft_out,
batch_size * config_.nhead * from_len, to_len,
stream);
}
void reset_size(size_t nhead) { config_.nhead = nhead; }
private:
Config config_;
};
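// Hedged usage sketch (hypothetical device pointers): applies the masked
// softmax in place over [batch_size, nhead, from_len, to_len] attention scores.
inline void attn_softmax_fw(float *scores, const float *attn_mask,
                            int batch_size, int nhead, int from_len, int to_len,
                            cudaStream_t &stream) {
  Softmax<float> softmax(Softmax<float>::Config(nhead));
  softmax.Forward(scores, attn_mask, batch_size, from_len, to_len, stream,
                  /*mask_future=*/false);
}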

View File

@ -1,100 +0,0 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>
#include "cublas_wrappers.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int m;
int n;
int k;
float alpha;
float beta;
cublasOperation_t op_A;
cublasOperation_t op_B;
std::array<int, 3> gemm_algos;
Config(float param_alpha, float param_beta, cublasOperation_t opA,
cublasOperation_t opB)
: alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
void SetConfig(int mm, int nn, int kk) {
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config &config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
cublasHandle_t handle) {
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(
handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
_buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
}
void Backward(int bsz, const T *d_output, const T *_buffer_a,
const T *_buffer_b, cublasHandle_t handle,
T *inpGradA = nullptr, T *inpGradB = nullptr) {
int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B needs to be transposed.
cublasOperation_t op_b =
(_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
// Calculate d_A.
cublas_strided_batched_gemm(
handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
(_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
(_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
cublasGemmAlgo_t(_config.gemm_algos[1]));
// A needs to be transposed.
cublasOperation_t op_a =
(_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(
handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
_buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
}
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
};
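// Hedged construction sketch: the attention-score instance later in this diff
// uses alpha = 1/sqrt(head_dim) with op_A = CUBLAS_OP_T, op_B = CUBLAS_OP_N,
// folding the softmax scaling into the batched Q*K^T GEMM.
#include <cmath>
inline StridedBatchGemm<float> make_attn_scores_gemm(int head_dim) {
  StridedBatchGemm<float>::Config cfg(1.0f / std::sqrt((float)head_dim), 0.0f,
                                      CUBLAS_OP_T, CUBLAS_OP_N);
  return StridedBatchGemm<float>(cfg);
}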

File diff suppressed because it is too large Load Diff

View File

@ -1,365 +0,0 @@
#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>
#include <cub/cub.cuh>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;
/**
@brief: softmax_kernel
Softmax forward kernel for
enc-self-attn, dec-self-attn, encdec-attn
@thread
gridDim.x = dynamic
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = from_len
@param
inp: [batch_size, nhead, from_len, to_len], softmax input.
attn_mask: [batch_size, to_len], padding tokens are -inf,
non-padding tokens are 0.
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
int to_len, bool mask_future) {
int batch_id = blockIdx.y;
int head_id = blockIdx.z;
const int nhead = gridDim.z;
const int token_per_reduce = 1;
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_dim, ele_per_thread,
cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
T mval[ele_per_thread];
if (attn_mask) {
attn_mask += batch_id * to_len;
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
}
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
token_id += gridDim.x * token_per_reduce) {
T inp_val[token_per_reduce][ele_per_thread];
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
REDUCE_FLOAT_INF_NEG);
}
/* step 1. compute max */
// thread local max
float val[token_per_reduce][ele_per_thread];
float l_max[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_max[i] = REDUCE_FLOAT_INF_NEG;
for (int j = 0; j < ele_per_thread; j++) {
if (attn_mask) {
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
} else {
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
val[i][j] = REDUCE_FLOAT_INF_NEG;
} else {
val[i][j] = (float)inp_val[i][j];
}
}
l_max[i] = fmaxf(l_max[i], val[i][j]);
}
}
// block reduce max
blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
// write shared
__shared__ float s_max[token_per_reduce];
if (threadIdx.x == 0) {
for (int i = 0; i < token_per_reduce; i++) {
s_max[i] = l_max[i];
}
}
__syncthreads();
/* step 2. compute sum */
// thread local sum
float l_sum[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_sum[i] = 0.f;
for (int j = 0; j < ele_per_thread; j++) {
val[i][j] = __expf(val[i][j] - s_max[i]);
l_sum[i] += val[i][j];
}
}
// block reduce sum
blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
// write shared
__shared__ float s_sum[token_per_reduce];
if (threadIdx.x == 0) {
for (int i = 0; i < token_per_reduce; i++) {
s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
}
}
__syncthreads();
/* step 3. compute final result */
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
for (int j = 0; j < ele_per_thread; j++) {
inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
}
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len);
}
} // blockIdx.x
}
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
int to_len, bool mask_future) {
int batch_id = blockIdx.y;
int head_id = blockIdx.z;
const int nhead = gridDim.z;
const int token_per_reduce = 1;
typedef cub::BlockLoad<T, block_dim, ele_per_thread,
cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_dim, ele_per_thread,
cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
T mval[ele_per_thread];
if (attn_mask) {
attn_mask += batch_id * to_len;
BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
}
inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
token_id += gridDim.x * token_per_reduce) {
T inp_val[token_per_reduce][ele_per_thread];
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
REDUCE_FLOAT_INF_NEG);
}
/* step 1. compute max */
// thread local max
float val[token_per_reduce][ele_per_thread];
float l_max[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_max[i] = REDUCE_FLOAT_INF_NEG;
for (int j = 0; j < ele_per_thread; j++) {
if (attn_mask) {
val[i][j] = (float)inp_val[i][j] + (float)mval[j];
} else {
if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
val[i][j] = REDUCE_FLOAT_INF_NEG;
} else {
val[i][j] = (float)inp_val[i][j];
}
}
l_max[i] = fmaxf(l_max[i], val[i][j]);
}
}
// warp reduce max
warpReduce<ReduceType::kMax, token_per_reduce>(l_max);
/* step 2. compute sum */
// thread local sum
float l_sum[token_per_reduce];
for (int i = 0; i < token_per_reduce; i++) {
l_sum[i] = 0.f;
for (int j = 0; j < ele_per_thread; j++) {
val[i][j] = __expf(val[i][j] - l_max[i]);
l_sum[i] += val[i][j];
}
}
// warp reduce sum
warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);
/* step 3. compute final result */
for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
for (int j = 0; j < ele_per_thread; j++) {
inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
}
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len);
}
} // blockIdx.x
}
/*
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
int batch_size, int nhead, int from_len,
int to_len, bool mask_future,
cudaStream_t stream) {
dim3 grid_dim(1, batch_size, nhead);
if (to_len <= 32) {
ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 64) {
ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 128) {
grid_dim.x = 16;
ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 256) {
grid_dim.x = 32;
ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 512) {
grid_dim.x = 64;
ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else {
throw std::runtime_error(
"Sequence length greater than 512 is currently not supported");
}
}
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
int batch_size, int nhead, int from_len,
int to_len, bool mask_future,
cudaStream_t stream) {
dim3 grid_dim(1, batch_size, nhead);
if (to_len <= 32) {
ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 64) {
ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 128) {
grid_dim.x = 8;
ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 256) {
grid_dim.x = 16;
ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else if (to_len <= 512) {
grid_dim.x = 32;
ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
inp, attn_mask, from_len, to_len, mask_future);
} else {
throw std::runtime_error(
"Sequence length greater than 512 is currently not supported");
}
}
/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention.
@thread
gridDim.x = batch_size * nhead * seq_len / warps_per_block
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
@param
grad: [batch_size, nhead, seq_len, seq_len], output grad.
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
*/
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
inp += offset;
T grad_reg[ITERATIONS];
T inp_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
inp_reg[i] = inp[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)inp_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
}
}
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
int softmax_len, cudaStream_t stream) {
const int warps_per_block = 4;
// rows = batch_size * nhead * from_len
dim3 grid_dim(rows / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (softmax_len <= 32)
ker_attn_softmax_bw<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 64)
ker_attn_softmax_bw<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 128)
ker_attn_softmax_bw<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 256)
ker_attn_softmax_bw<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 384)
ker_attn_softmax_bw<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 512)
ker_attn_softmax_bw<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 768)
ker_attn_softmax_bw<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 1024)
ker_attn_softmax_bw<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else if (softmax_len <= 2048)
ker_attn_softmax_bw<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
else
throw std::runtime_error(
std::string(
"Unsupported sequence length in softmax backward, seq_len: ") +
std::to_string(softmax_len));
}
template void launch_attn_softmax_bw<__half>(__half *out_grad,
const __half *soft_inp, int rows,
int softmax_len,
cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
const float *soft_inp, int rows,
int softmax_len,
cudaStream_t stream);
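// Note (assumption, not original code): in the backward dispatch above each
// warp owns one softmax row, so the chosen ITERATIONS must satisfy
// ITERATIONS * WARP_SIZE >= softmax_len; rows is assumed to be a multiple of
// warps_per_block (4).
static inline int min_softmax_bw_iterations(int softmax_len) {
  return (softmax_len + WARP_SIZE - 1) / WARP_SIZE;
}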

View File

@ -1,314 +0,0 @@
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape the input
during the backward pass of encoder self-attention
@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
input: [batch_size, seq_len, hidden_dim]
output: [batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
int head_dim);
template <>
__global__ void transform_0213<float>(float *output, const float *input,
int hidden_dim, int head_dim) {
int batch_id = blockIdx.x;
int token_id = blockIdx.y;
int seq_len = gridDim.y;
int nhead = hidden_dim / head_dim;
// [b, s, h]
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
// [b, nh, s, ad]
int trg_offset =
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vinput4;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinput4 = input4[src_offset + i];
int head_id = i / head_dim;
int dim_id = i % head_dim;
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
res4[trg_offset + cur_trg_offset] = vinput4;
}
}
template <>
__global__ void transform_0213<__half>(__half *output, const __half *input,
int hidden_dim, int head_dim) {
int batch_id = blockIdx.x;
int token_id = blockIdx.y;
int seq_len = gridDim.y;
int nhead = hidden_dim / head_dim;
// [b, s, h]
int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
// [b, nh, s, ad]
int trg_offset =
flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vinput4;
for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
vinput4 = input4[src_offset + i];
int head_id = i / head_dim;
int dim_id = i % head_dim;
int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
res4[trg_offset + cur_trg_offset] = vinput4;
}
}
// [b, s, h] -> [b, nh, s, ad]
template <>
void launch_transform_0213<float>(float *output, const float *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, cudaStream_t stream) {
hidden_dim >>= 2;
int head_dim = hidden_dim / nhead;
dim3 grid_dim(batch_size, seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
transform_0213<float>
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
}
template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, cudaStream_t stream) {
hidden_dim >>= 3;
int head_dim = hidden_dim / nhead;
dim3 grid_dim(batch_size, seq_len);
dim3 block_dim(min(hidden_dim, MAX_THREADS));
transform_0213<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, input, hidden_dim, head_dim);
}
/**
@brief: bias_add_transform_20314
Add bias to input, transform from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
const T *bias, int dim_3, int dim_4);
template <>
__global__ void bias_add_transform_20314<float>(float *output,
const float *input,
const float *bias, int dim_3,
int dim_4) {
int id0 = blockIdx.x;
int id1 = blockIdx.y;
int id2 = blockIdx.z;
int dim_0 = gridDim.x;
int dim_1 = gridDim.y;
int dim_2 = gridDim.z;
int dim_34 = dim_3 * dim_4;
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
int bias_offset = flat_2dim(id2, 0, dim_34);
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vqkv4;
float4 vbias4;
float4 vres4;
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
vqkv4 = qkv4[src_offset + i];
vbias4 = bias4[bias_offset + i];
vres4.x = vqkv4.x + vbias4.x;
vres4.y = vqkv4.y + vbias4.y;
vres4.z = vqkv4.z + vbias4.z;
vres4.w = vqkv4.w + vbias4.w;
int id3 = i / dim_4;
int id4 = i % dim_4;
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
res4[trg_offset + cur_trg_offset] = vres4;
}
}
template <>
__global__ void bias_add_transform_20314<__half>(__half *output,
const __half *input,
const __half *bias, int dim_3,
int dim_4) {
int id0 = blockIdx.x;
int id1 = blockIdx.y;
int id2 = blockIdx.z;
int dim_0 = gridDim.x;
int dim_1 = gridDim.y;
int dim_2 = gridDim.z;
int dim_34 = dim_3 * dim_4;
int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
int bias_offset = flat_2dim(id2, 0, dim_34);
const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *res4 = reinterpret_cast<float4 *>(output);
float4 vqkv4;
float4 vbias4;
float4 vres4;
__half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
__half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
__half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);
for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
vqkv4 = qkv4[src_offset + i];
vbias4 = bias4[bias_offset + i];
h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);
int id3 = i / dim_4;
int id4 = i % dim_4;
int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
res4[trg_offset + cur_trg_offset] = vres4;
}
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
const float *bias, int dim_0,
int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream) {
dim_4 >>= 2;
dim3 grid_dim(dim_0, dim_1, dim_2);
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
bias_add_transform_20314<float>
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
}
template <>
void launch_bias_add_transform_20314<__half>(__half *output,
const __half *input,
const __half *bias, int dim_0,
int dim_1, int dim_2, int dim_3,
int dim_4, cudaStream_t stream) {
dim_4 >>= 3;
dim3 grid_dim(dim_0, dim_1, dim_2);
dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));
bias_add_transform_20314<__half>
<<<grid_dim, block_dim, 0, stream>>>(output, input, bias, dim_3, dim_4);
}
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
input: [trans_count, batch_size, nhead, seq_len, head_dim]
output: [batch_size, seq_len, trans_count, nhead, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
trans_count: 1 or 3, the number of matrices to be transformed
*/
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
int seq_len, int trans_count, int nhead,
int head_dim, int num_all) {
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset >= num_all) {
return;
}
int trans_id, batch_id, head_id, token_id, dim_id;
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
&batch_id, &head_id, &token_id, &dim_id);
// [b, s, tc, nh, ad]
int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
seq_len, trans_count, nhead, head_dim);
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *res4 = reinterpret_cast<float4 *>(output);
res4[trg_offset] = input4[offset];
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
int batch_size, int seq_len, int hidden_dim,
int nhead, int trans_count,
cudaStream_t stream) {
hidden_dim >>= 2;
int head_dim = hidden_dim / nhead;
int num_all = batch_size * seq_len * trans_count * hidden_dim;
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
num_all);
}
template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
int batch_size, int seq_len,
int hidden_dim, int nhead, int trans_count,
cudaStream_t stream) {
hidden_dim >>= 3;
int head_dim = hidden_dim / nhead;
int num_all = batch_size * seq_len * trans_count * hidden_dim;
int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;
transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
output, input, batch_size, seq_len, trans_count, nhead, head_dim,
num_all);
}

View File

@ -1,406 +0,0 @@
#include "multihead_attention_1d.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp>
#endif
#include <iostream>
#include "context.h"
#include "kernels.h"
template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_size,
int num_heads,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm)
: _layer_id(layer_id),
_max_batch_tokens(max_batch_tokens),
_max_seq_len(max_seq_len),
_hidden_size(hidden_size),
_heads(num_heads),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_qkv_linear(
typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
_attn_out_linear(
typename FeedForward<T>::Config(hidden_size, hidden_size)),
_attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
_max_batch_tokens),
_softmax(typename Softmax<T>::Config(num_heads)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
_max_batch_tokens * _heads * _max_seq_len),
_attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
_max_batch_tokens * _hidden_size),
_attn_scores(typename StridedBatchGemm<T>::Config(
(T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
CUBLAS_OP_N)),
_attn_context(typename StridedBatchGemm<T>::Config(
T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
assert(_hidden_size % _heads == 0);
}
template <typename T>
MultiHeadAttention<T>::~MultiHeadAttention() {
free_mem_buffer();
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
const T *input_mask_ptr,
T *output_ptr, T *buffer) {
T *q_tf_ptr = _qkv_ptr;
T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
if (_pre_or_postLayerNorm) {
_attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
_batch_tokens, _stream);
}
const T *gemmQKV_inp_ptr =
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
_qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
_cublasHandle);
launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
_batch_size, _seq_len, 3, _heads / pg_size,
_hidden_size / _heads, _stream);
// attention scores, q*k
_attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
_cublasHandle);
// Softmax + Mask
_softmax.reset_size(_heads / pg_size);
_softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
_seq_len, _stream, true);
// attn prob dropout.
_attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
_batch_heads * _seq_len * _seq_len, _stream);
// attention context, score * v
_attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
_cublasHandle);
// [b, nh, s, ad] -> [b, s, nh, ad]
launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
_hidden_size / pg_size, _heads / pg_size, 1,
_stream);
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
_attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
output_ptr, _cublasHandle);
// allreduce
if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
} else {
auto data_type = torch::kFloat;
if (typeid(T) != typeid(float)) {
data_type = torch::kHalf;
}
auto output_tensor = torch::from_blob(
output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
work->wait();
}
_attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
_attn_ob_ptr, _batch_tokens, _hidden_size,
_stream);
if (!_pre_or_postLayerNorm) {
// in-place ln since ln-input will not be used in post-ln mode
_attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
_batch_tokens, _stream);
}
}
template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
T *out_ptr) {
_stream = Context::Instance().get_stream();
_cublasHandle = Context::Instance().get_cublashandle();
T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim
attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer);
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
const T *input_mask_ptr,
const T *output_ptr,
const T *grad_output_ptr,
T *grad_input_ptr, T *buffer) {
cudaStream_t streams[2] = {_stream, _stream};
const T *q_tf_ptr = _qkv_ptr;
const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
// batch_dim = batch_size * seq_len * hidden_size
// buffer size: batch_dim * 3 + max(batch_dim * 3,
// batch_size * head_num * seq_len * seq_len)
T *grad_residual_ptr = buffer;
buffer += _batch_dim;
T *grad_input_buf_ptr = buffer; // batch_dim
T *grad_qkv_5d_ptr = buffer; // batch_dim * 3
buffer += 3 * _batch_dim / pg_size;
T *grad_qkv_4d_ptr = buffer; // batch_dim * 3
T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len
// buffer += max(3 * _batch_dim,
// batch_size * head_num * seq_len * seq_len);
if (_pre_or_postLayerNorm) {
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
grad_output_ptr, _batch_tokens,
_hidden_size, _stream);
} else {
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
_attn_nb_ptr, _batch_tokens, streams);
_attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
grad_residual_ptr, _batch_tokens,
_hidden_size, _stream);
}
// bw of output project
_attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
_attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
_attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
false);
launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
_seq_len, _hidden_size / pg_size, _heads / pg_size,
_stream);
// bw of score * v
_attn_context.Backward(
_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
_attn_prob_dropout.d_dropout(grad_softmax_ptr,
_batch_heads * _seq_len * _seq_len, _stream);
_softmax.reset_size(_heads / pg_size);
_softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
_seq_len, _stream);
// bw of q * k
_attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
_cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
grad_qkv_5d_ptr);
// [3, b, nh, s, ad] -> [b, s, 3, h]
launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
_seq_len, _hidden_size / pg_size, _heads / pg_size,
3, _stream);
const T *gemmQKV_inp_ptr =
_pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
_qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
_qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
_attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
_cublasHandle, _stream, grad_input_buf_ptr, nullptr,
true);
// allreduce across tensor-parallel ranks when a process group with more than one rank is set
if (pg != c10::detail::UniqueVoidPtr() && pg->getSize() > 1) {
auto data_type = torch::kFloat;
if (typeid(T) != typeid(float)) {
data_type = torch::kHalf;
}
auto grad_input_tensor =
torch::from_blob(grad_input_buf_ptr,
{int(_batch_size), int(_seq_len), int(_hidden_size)},
torch::TensorOptions(torch::kCUDA).dtype(data_type));
std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
work->wait();
}
if (_pre_or_postLayerNorm) {
_attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
_attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
} else {
// FIXME later
launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
_batch_size, _seq_len, _hidden_size, _stream);
}
}
template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
const T *input_ptr, const T *output_ptr,
const T *input_mask_ptr,
T *grad_input_ptr) {
_stream = Context::Instance().get_stream();
_cublasHandle = Context::Instance().get_cublashandle();
T *buffer = _shared_mem_ptr;
/*
buffer size needed by attn bw:
4 * _batch_dim + max(3 * _batch_dim,
_batch_size * _head_num * _seq_len * _seq_len);
*/
attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
grad_input_ptr, buffer);
}
template <typename T>
void MultiHeadAttention<T>::SetTrainingMode(bool training) {
// Dropout will be skipped when not in training mode.
_attn_prob_dropout.SetTrainingMode(training);
_attn_dropout.SetTrainingMode(training);
}
template <typename T>
T *MultiHeadAttention<T>::_shared_mem_ptr = nullptr;
template class MultiHeadAttention<float>;
template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_dim, int num_heads,
float attn_prob_dropout_ratio,
float hidden_dropout_ratio,
bool pre_or_postLayerNorm,
c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Context::Instance().set_stream(stream);
auto layer = std::make_shared<MultiHeadAttention<T>>(
layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
layer->SetPG(pg_);
s_multihead_attention[layer_id] = layer;
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
return 0;
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
bool training_mode, bool prelayernorm) {
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
const T *input_ptr = (const T *)input.data_ptr();
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
auto output = torch::empty_like(input);
T *out_ptr = (T *)output.data_ptr();
std::shared_ptr<MultiHeadAttention<T>> layer =
std::static_pointer_cast<MultiHeadAttention<T>>(
s_multihead_attention[layer_id]);
layer->set_cur_batch_shape(input.size(0), input.size(1));
layer->SetTrainingMode(training_mode);
layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();
layer->Forward(input_ptr, input_mask_ptr, out_ptr);
return {output};
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
int layer_id, const torch::Tensor &grad_dec_output,
const torch::Tensor &output, const torch::Tensor &input,
const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
const torch::Tensor &norm_bias) {
auto g_output = grad_dec_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
auto grad_input = torch::empty_like(input);
auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
auto grad_norm_weight = torch::empty_like(norm_weight);
auto grad_norm_bias = torch::empty_like(norm_bias);
// inputs.
const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
const T *input_ptr = (const T *)input.data_ptr();
const T *output_ptr = (const T *)output.data_ptr();
const T *input_mask_ptr = (const T *)input_mask.data_ptr();
// outputs.
T *grad_input_ptr = (T *)grad_input.data_ptr();
std::shared_ptr<MultiHeadAttention<T>> layer =
std::static_pointer_cast<MultiHeadAttention<T>>(
s_multihead_attention[layer_id]);
layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
grad_input_ptr);
return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
grad_norm_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
"Multi-head Attention forward with fp32 (CUDA)");
m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
"Multi-head Attention forward with fp16 (CUDA)");
m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
"Multi-head Attention backward with fp32 (CUDA)");
m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
"Multi-head Attention backward with fp16 (CUDA)");
m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
"Create Multi-head Attention with fp32 (CUDA)");
m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
"Create Multi-head Attention with fp16 (CUDA)");
}
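For orientation, a minimal sketch of how these bindings are driven from Python; it mirrors the wrapper module further below, but the shapes, dropout ratios and layer id are arbitrary illustration values:

import torch
from colossalai.kernel.op_builder import MultiHeadAttnBuilder

# JIT-build (or load the prebuilt) extension that contains the bindings above.
mha_cuda = MultiHeadAttnBuilder().load()

bsz, seq_len, hidden, nhead, layer_id = 2, 128, 1024, 16, 0

# Register one layer on the C++ side (fp32 variant, no process group).
mha_cuda.create_multihead_attention_fp32(
    layer_id, bsz * seq_len, seq_len, hidden, nhead, 0.1, 0.1, True, None
)

x = torch.randn(bsz, seq_len, hidden, device="cuda")
mask = torch.zeros(bsz, seq_len, device="cuda")  # additive mask, 0 = attend
in_proj_w = torch.randn(3 * hidden, hidden, device="cuda")
in_proj_b = torch.zeros(3 * hidden, device="cuda")
out_proj_w = torch.randn(hidden, hidden, device="cuda")
out_proj_b = torch.zeros(hidden, device="cuda")
norm_w = torch.ones(hidden, device="cuda")
norm_b = torch.zeros(hidden, device="cuda")

(out,) = mha_cuda.multihead_attention_fw_fp32(
    layer_id, x, mask, in_proj_w, in_proj_b, out_proj_w, out_proj_b,
    norm_w, norm_b, False, True
)
print(out.shape)  # torch.Size([2, 128, 1024])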

View File

@ -1,167 +0,0 @@
#pragma once
#include <c10/util/intrusive_ptr.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif
#include <string>
#include <type_traits>
#include "cuda_util.h"
#include "dropout.h"
#include "feed_forward.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"
template <typename T>
class MultiHeadAttention {
public:
MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
int hidden_size, int num_heads, float attn_dropout_ratio,
float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm);
virtual ~MultiHeadAttention();
void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
void Backward(const T *grad_output_ptr, const T *input_ptr,
const T *output_ptr, const T *input_mask_ptr,
T *grad_input_ptr);
void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
T *buffer);
void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
const T *output_ptr, const T *grad_output_ptr,
T *grad_input_attn_layer_bwptr, T *buffer);
void set_cur_batch_shape(int batch_size, int seq_len) {
_batch_size = batch_size;
_seq_len = seq_len;
_batch_tokens = batch_size * seq_len;
_batch_heads = batch_size * _heads / pg_size;
_batch_dim = _batch_tokens * _hidden_size;
_attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
}
void SetTrainingMode(bool training);
inline bool IsTrainingMode() const { return _training; }
void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
pg = pg_;
pg_size = 1;
if (pg != c10::detail::UniqueVoidPtr()) {
pg_size = pg->getSize();
}
allocate_mem_buffer();
}
// weights ptr
const T *_attn_qkvw_ptr;
const T *_attn_qkvb_ptr;
const T *_attn_ow_ptr;
const T *_attn_ob_ptr;
const T *_attn_nw_ptr;
const T *_attn_nb_ptr;
// grads ptr
T *_grad_attn_qkvw_ptr;
T *_grad_attn_qkvb_ptr;
T *_grad_attn_ow_ptr;
T *_grad_attn_ob_ptr;
T *_grad_attn_nw_ptr;
T *_grad_attn_nb_ptr;
private:
void allocate_mem_buffer() {
// allocate local gpu memory
if (_pre_or_postLayerNorm) {
_gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
} else {
_gemmQKV_inp_ptr = nullptr;
}
_qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
_soft_out_ptr =
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_ctx_bufB_ptr =
cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
_attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
// buffer size needed by attn bw
size_t smem_size =
4 * _max_batch_tokens * _hidden_size / pg_size +
std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
_max_batch_tokens * _heads / pg_size * _max_seq_len);
if (!_shared_mem_ptr) {
cuda_free(_shared_mem_ptr);
_shared_mem_ptr = cuda_malloc<T>(smem_size);
}
}
void free_mem_buffer() {
// free local gpu memory
cuda_free(_gemmQKV_inp_ptr);
cuda_free(_qkv_ptr);
cuda_free(_soft_out_ptr);
cuda_free(_ctx_bufB_ptr);
cuda_free(_attn_o_inp_ptr);
// free shared gpu memory between layers
cuda_free(_shared_mem_ptr);
_shared_mem_ptr = nullptr;
}
// const parameter between batch
const size_t _layer_id;
const size_t _hidden_size;
const size_t _heads;
const size_t _max_batch_tokens;
const size_t _max_seq_len;
const bool _pre_or_postLayerNorm;
// dynamic parameter between batch
size_t _batch_size;
size_t _seq_len;
size_t _batch_tokens;
size_t _batch_heads;
size_t _batch_dim;
bool _training;
cublasHandle_t _cublasHandle;
cudaStream_t _stream;
// layers
FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _attn_ln;
Softmax<T> _softmax;
Dropout<T> _attn_prob_dropout;
Dropout<T> _attn_dropout;
StridedBatchGemm<T> _attn_scores;
StridedBatchGemm<T> _attn_context;
// local GPU memory
T *_gemmQKV_inp_ptr;
T *_qkv_ptr;
T *_soft_out_ptr;
T *_ctx_bufB_ptr;
T *_attn_o_inp_ptr;
// shared GPU memory between layer
static T *_shared_mem_ptr;
c10::intrusive_ptr<c10d::ProcessGroup> pg;
int pg_size;
};
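To make the workspace sizing in allocate_mem_buffer() concrete, a small arithmetic sketch; the dimensions are made up, and the result counts elements of T (not bytes):

max_batch_tokens, hidden_size, heads, max_seq_len, pg_size = 4096, 1024, 16, 512, 1
smem_elems = (
    4 * max_batch_tokens * hidden_size // pg_size
    + max(3 * max_batch_tokens * hidden_size // pg_size,
          max_batch_tokens * heads // pg_size * max_seq_len)
)
# 16_777_216 + max(12_582_912, 33_554_432) = 50_331_648 elements, shared by all layers
print(smem_elems)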

View File

@ -1,8 +0,0 @@
#include <torch/extension.h>
#include "linear.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
"Linear SiLU (INT8)");
}

View File

@ -1,162 +0,0 @@
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
) {
auto M = input.size(0);
auto N = weight.size(0);
auto K = input.size(1);
using ElementOutput = float;
using ElementAccumulator = int32_t;
using ElementComputeEpilogue = float;
using ElementInputA = int8_t; // <- data type of elements in input matrix A
using ElementInputB = int8_t; // <- data type of elements in input matrix B
// The code section below describes matrix layout of input and output
// matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major
// for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
#if CUDA_ARCH >= 800
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<
ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This
// becomes the vector width of math
// instructions in epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue // <- data type for alpha in linear combination
// function
>;
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
EpilogueOp>;
#elif CUDA_ARCH >= 700
#define USE_TORCH_SILU
using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>;
using Gemm = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
DefaultGemmCfg::InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
#error "Unsupported cuda arch"
#endif
auto input_size = cutlass::MatrixCoord(M, K);
auto weight_size = cutlass::MatrixCoord(K, N);
auto output_size = cutlass::MatrixCoord(M, N);
auto device = input.device();
// use the broadcasted bias as the output
auto out = bias.to(device).view({1, -1}).repeat({M, 1});
// constexpr int kSparse = Gemm::kSparse;
// How many elements of A are covered per ElementE
// constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// The size of individual meta data
// constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
input_ref, // <- reference to matrix A on device
weight_ref, // <- reference to matrix B on device
out_ref, // <- reference to matrix C on device
out_ref, // <- reference to matrix D on device
{alpha, beta}, 1};
Gemm gemm_op;
// Using the arguments, query for extra workspace required for matrix
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot implement");
}
// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot run");
}
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
out = torch::silu(out);
#endif
return out;
}
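A hedged usage sketch for this op from Python, assuming the binding above has been compiled into an extension module; the variable name `smoothquant_ops` and the scales are illustrative:

import torch

# Assumption: `smoothquant_ops` is the loaded extension that exposes the binding above.
M, K, N = 16, 4096, 11008
x_int8 = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
w_int8 = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")
bias_fp32 = torch.zeros(N, dtype=torch.float32, device="cuda")

# alpha folds the input/weight dequantization scales, beta scales the fp32 bias.
y = smoothquant_ops.linear_silu_a8_w8_bfp32_ofp32(x_int8, w_int8, bias_fp32, 1.0 / (127 * 127), 1.0)
print(y.shape, y.dtype)  # torch.Size([16, 11008]) torch.float32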

View File

@ -1,12 +0,0 @@
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8
torch::Tensor weight, // INT8
torch::Tensor bias, // FP32
float alpha, // FP32
float beta // FP32
);

View File

@ -1,338 +0,0 @@
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.autograd import Function
def check_config(config):
if config.hidden_size % config.nhead != 0:
raise Exception("hidden_size % nhead != 0")
factor = 8 if config.fp16 else 4
upbound = factor * 1024 * 4
if config.hidden_size > upbound:
# as required by ln backward kernel currently
raise Exception(f"hidden_size > {upbound}")
head_dim = config.hidden_size // config.nhead
if head_dim % factor != 0:
# as required by reshape kernel
raise Exception(f"head_dim({head_dim}) % {factor} != 0")
def calc_offset(sizes):
offsets = [0]
tmp = 0
for x in sizes:
tmp += x
offsets.append(tmp)
return offsets
colossal_multihead_attention = None
@dataclass
class Config:
max_batch_tokens: int # max batch token numbers
max_seq_len: int # max sequence length
hidden_size: int # size of transformer hidden layers
nhead: int # number of heads in attention
attn_prob_dropout_ratio: float # attention score dropout ratio
hidden_dropout_ratio: float # dropout ratio before the residual connection
norm_first: bool # apply LayerNorm before attention (pre-LN) instead of after
fp16: bool # fp16 precision
class MultiHeadAttention1DFunc(Function):
@staticmethod
def forward(
ctx,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config,
):
cuda_module = colossal_multihead_attention
forward_func = (
cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32
)
if config.fp16:
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
(output,) = forward_func(
config.layer_id,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
config.training,
config.norm_first,
)
if config.is_grad_enabled and config.training:
ctx.save_for_backward(
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
ctx.config = config
return output
@staticmethod
def backward(ctx, grad_output):
assert ctx.config.training
cuda_module = colossal_multihead_attention
backward_func = (
cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32
)
(
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
) = ctx.saved_tensors
grad_input = None
grad_in_proj_weight = None
grad_in_proj_bias = None
grad_out_proj_weight = None
grad_out_proj_bias = None
grad_norm_weight = None
grad_norm_bias = None
if ctx.config.fp16:
grad_output = grad_output.to(torch.half)
output = output.to(torch.half)
input = input.to(torch.half)
input_mask = input_mask.to(torch.half)
(
grad_input,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
) = backward_func(
ctx.config.layer_id,
grad_output,
output,
input,
input_mask,
in_proj_weight,
in_proj_bias,
out_proj_weight,
out_proj_bias,
norm_weight,
norm_bias,
)
return (
grad_input,
None,
grad_in_proj_weight,
grad_in_proj_bias,
grad_out_proj_weight,
grad_out_proj_bias,
grad_norm_weight,
grad_norm_bias,
None,
)
class MultiHeadAttention(nn.Module):
"""Initialize the MultiHeadAttention.
Static variable:
layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.
Arguments:
hidden_size: Total dimension of hidden_size.
nhead: Number of parallel attention heads.
batch_size: Batch size for one forward pass.
max_seq_len: Maximum length of the input sequence.
dropout: Dropout probability.
norm_first: Whether to apply LayerNorm before attention (pre-LN) rather than after (post-LN).
"""
layer_id = 0
def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
super(MultiHeadAttention, self).__init__()
self.config = Config(
batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
)
check_config(self.config)
self.pg = pg
self.pg_size = 1
if self.pg:
self.pg_size = pg.size()
self.config.layer_id = MultiHeadAttention.layer_id
MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1
# Load cuda modules if needed
global colossal_multihead_attention
if colossal_multihead_attention is None:
from colossalai.kernel.op_builder import MultiHeadAttnBuilder
multihead_attention = MultiHeadAttnBuilder().load()
colossal_multihead_attention = multihead_attention
# create the layer in cuda kernels.
cuda_module = colossal_multihead_attention
create_layer_func = (
cuda_module.create_multihead_attention_fp16
if self.config.fp16
else cuda_module.create_multihead_attention_fp32
)
create_layer_func(
self.config.layer_id,
self.config.max_batch_tokens,
self.config.max_seq_len,
self.config.hidden_size,
self.config.nhead,
self.config.attn_prob_dropout_ratio,
self.config.hidden_dropout_ratio,
self.config.norm_first,
self.pg,
)
hs = self.config.hidden_size
self.precision = torch.float32
if self.config.fp16:
self.precision = torch.half
self.hs_per_rank = int(hs / self.pg_size)
self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
self.norm_weight = nn.Parameter(torch.Tensor(hs))
self.norm_bias = nn.Parameter(torch.Tensor(hs))
self.reset_parameters()
torch.cuda.empty_cache()
def calc_bound(self, w):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
bound = 1.0 / math.sqrt(fan_in)
return bound
def reset_parameters(self):
hs = self.config.hidden_size
nn.init.zeros_(self.out_proj_bias)
nn.init.ones_(self.norm_weight)
nn.init.zeros_(self.norm_bias)
if self.pg_size > 1:
rank_in_pg = torch.distributed.get_rank(self.pg)
attn_qkvw_global = torch.empty(hs * 3, hs)
attn_qkvb_global = torch.empty(hs * 3)
nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
bound = self.calc_bound(attn_qkvw_global)
nn.init.uniform_(attn_qkvb_global, -bound, bound)
attn_qkvw_global = attn_qkvw_global.cuda()
attn_qkvb_global = attn_qkvb_global.cuda()
torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
attn_qkvw_global = attn_qkvw_global.cpu()
attn_qkvb_global = attn_qkvb_global.cpu()
with torch.no_grad():
self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
]
)
self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[
:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
]
)
attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0)
attn_ow_global = attn_ow_global.cuda()
torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
attn_ow_global = attn_ow_global.cpu()
with torch.no_grad():
self.out_proj_weight.copy_(
attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
)
else:
attn_qkvw = self.in_proj_weight.view(-1, hs)
nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
bound = self.calc_bound(attn_qkvw)
nn.init.uniform_(self.in_proj_bias, -bound, bound)
nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
def state_dict(self, destination=None, prefix="", keep_vars=False):
destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
return destination
def forward(self, hidden_states, encoder_padding_mask):
self.config.training = self.training
self.config.is_grad_enabled = torch.is_grad_enabled()
hidden_states = hidden_states.contiguous()
encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()
bs, sl, dim = hidden_states.size()
if bs * sl > self.config.max_batch_tokens:
raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
if sl > self.config.max_seq_len:
raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
if len(encoder_padding_mask.size()) == 1:
assert bs == 1 and sl == encoder_padding_mask.size(0)
else:
assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
output = MultiHeadAttention1DFunc.apply(
hidden_states,
encoder_padding_mask,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.norm_weight,
self.norm_bias,
self.config,
)
return output.to(self.precision)
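A minimal end-to-end sketch of the module above (single GPU, no tensor-parallel process group); the sizes are arbitrary and a CUDA device plus a successful kernel build are assumed:

import torch

layer = MultiHeadAttention(
    hidden_size=1024, nhead=16, batch_size=8, max_seq_len=256,
    dropout=0.1, norm_first=True, fp16=True,
).cuda()

hidden_states = torch.randn(8, 256, 1024, device="cuda", dtype=torch.half)
# 1 marks padded positions; forward() turns this into an additive -1e8 mask.
padding_mask = torch.zeros(8, 256, device="cuda")

out = layer(hidden_states, padding_mask)
print(out.shape, out.dtype)  # torch.Size([8, 256, 1024]) torch.float16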

View File

@ -0,0 +1 @@
../../extensions

View File

@ -1,21 +0,0 @@
from abc import ABC, abstractmethod
from typing import Callable
class BaseExtension(ABC):
@abstractmethod
def requires_build(self) -> bool:
pass
@abstractmethod
def build(self) -> None:
pass
@abstractmethod
def load(self) -> Callable:
pass
def fetch(self) -> Callable:
if self.requires_build:
self.build()
return self.load()
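A toy subclass, only to make the fetch() contract concrete; the class and the returned callable are hypothetical:

from typing import Callable

class NoopExtension(BaseExtension):
    """Hypothetical extension that ships pure-Python code and needs no build."""

    @property
    def requires_build(self) -> bool:
        return False

    def build(self) -> None:
        pass  # nothing to compile

    def load(self) -> Callable:
        return lambda *args, **kwargs: None

kernel = NoopExtension().fetch()  # skips build() and returns the callable from load()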

View File

@ -1,4 +0,0 @@
from .arm_extension import ArmCPUAdamExtension
from .x86_extension import X86CPUAdamExtension
__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"]

View File

@ -1,53 +0,0 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
class ArmCPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = ArmCPUAdamBuilder()
self._requires_build = False
@property
def requires_build(self) -> bool:
return self._requires_build
def build(self):
self.kernel_builder.build()
self._requires_build = True
def load(self):
return self.kernel_builder.load()
class ArmCPUAdamBuilder(ExtensionBuilder):
NAME = "arm_cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
ext_type = "cpu"
def __init__(self):
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes")]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []

View File

@ -1,65 +0,0 @@
from ..base_extension import BaseExtension
from ..extension_builder import ExtensionBuilder
from ..utils import append_nvcc_threads
class X86CPUAdamExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
self.kernel_builder = X86CPUAdamBuilder()
self._requires_build = False
@property
def requires_build(self) -> bool:
return self._requires_build
def build(self):
self.kernel_builder.build()
self._requires_build = True
def load(self):
return self.kernel_builder.load()
class X86CPUAdamBuilder(ExtensionBuilder):
NAME = "cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"
def __init__(self):
super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-lcudart",
"-lcublas",
"-g",
"-Wno-reorder",
"-fopenmp",
"-march=native",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
]
ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags
return append_nvcc_threads(ret)
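Loading the x86 CPU Adam kernel through the wrapper above might look like this; the first call JIT-compiles cpu_adam.cpp unless a prebuilt colossalai._C.cpu_adam is importable:

ext = X86CPUAdamExtension()
ext.build()            # import the prebuilt op or JIT-compile and cache it
cpu_adam = ext.load()  # the compiled module consumed by ColossalAI's CPUAdam optimizer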

View File

@ -1,243 +0,0 @@
# This code has been adapted from the DeepSpeed library.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import importlib
import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
class ExtensionBuilder(ABC):
"""
Builder is the base class to build extensions for PyTorch.
Args:
name (str): the name of the kernel to be built
prebuilt_import_path (str): the path where the extension is installed during pip install
"""
ext_type: str = "cuda"
def __init__(self, name: str, prebuilt_import_path: str):
self.name = name
self.prebuilt_import_path = prebuilt_import_path
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# we store the op as an attribute to avoid repeated building and loading
self.cached_op_module = None
assert prebuilt_import_path.startswith(
"colossalai._C"
), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}"
def relative_to_abs_path(self, code_path: str) -> str:
"""
This function takes in a path relative to the colossalai root directory and returns the absolute path.
"""
op_builder_module_path = Path(__file__).parent
# if we install from source
# the current file path will be op_builder/builder.py
# if we install via pip install colossalai
# the current file path will be colossalai/kernel/op_builder/builder.py
# this is because that the op_builder inside colossalai is a symlink
# this symlink will be replaced with actual files if we install via pypi
# thus we cannot tell the colossalai root directory by checking whether the op_builder
# is a symlink, we can only tell whether it is inside or outside colossalai
if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"):
root_path = op_builder_module_path.parent.parent
elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"):
root_path = op_builder_module_path.parent.parent
else:
root_path = op_builder_module_path.parent.joinpath("colossalai")
code_abs_path = root_path.joinpath(code_path)
return str(code_abs_path)
def get_cuda_home_include(self):
"""
Return the include path inside CUDA_HOME.
"""
from torch.utils.cpp_extension import CUDA_HOME
if CUDA_HOME is None:
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path)
# functions that must be overridden: begin
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
raise NotImplementedError
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
@abstractmethod
def nvcc_flags(self) -> List[str]:
"""
This function should return a list of nvcc compilation flags for extensions.
"""
# functions that must be overridden: end
def strip_empty_entries(self, args):
"""
Drop any empty strings from the list of compile and link flags
"""
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def check_runtime_build_environment(self):
"""
Check whether the system environment is ready for extension compilation.
"""
try:
from torch.utils.cpp_extension import CUDA_HOME
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
CUDA_HOME = None
if not TORCH_AVAILABLE:
raise ModuleNotFoundError(
"PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions"
)
if CUDA_HOME is None:
raise RuntimeError(
"CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions"
)
# make sure CUDA is available for compilation during runtime build
cuda_available = check_cuda_availability()
if not cuda_available:
raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.")
# make sure the system CUDA and the PyTorch CUDA versions match; an error will be raised inside the function if not
check_system_pytorch_cuda_match(CUDA_HOME)
def build(self, verbose: Optional[bool] = None):
"""
Build the kernel if it was not pre-built during pip install.
If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
kernel is built during pip install, it can be accessed through `colossalai._C`.
Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
if verbose is None:
verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1"
try:
# if the kernel has been pre-built during installation
# we just directly import it
op_module = self.import_op()
if verbose:
print_rank_0(
f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building."
)
except ImportError:
# check environment
if self.ext_type == "cuda":
self.check_runtime_build_environment()
# time the kernel compilation
start_build = time.time()
# construct the build directory
import torch
from torch.utils.cpp_extension import load
torch_version_major = torch.__version__.split(".")[0]
torch_version_minor = torch.__version__.split(".")[1]
torch_cuda_version = torch.version.cuda
home_directory = os.path.expanduser("~")
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
build_directory = os.path.join(home_directory, extension_directory)
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")
# load the kernel
op_module = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_cuda_cflags=self.nvcc_flags(),
extra_ldflags=[],
build_directory=build_directory,
verbose=verbose,
)
build_duration = time.time() - start_build
# log jit compilation time
if verbose:
print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")
# cache the built/loaded kernel
self.cached_op_module = op_module
def load(self, verbose: Optional[bool] = None):
"""
load the kernel during runtime.
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
# if the kernel has been compiled and cached, we directly use it
assert self.cached_op_module is not None, "Please build the kernel first before loading it."
return self.cached_op_module
def builder(self) -> Union["CUDAExtension", "CppExtension"]:
"""
get a CUDAExtension instance used for setup.py
"""
from torch.utils.cpp_extension import CppExtension, CUDAExtension
if self.ext_type == "cpp":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
return CUDAExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args={
"cxx": self.strip_empty_entries(self.cxx_flags()),
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
},
)
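For ahead-of-time builds, builder() plugs directly into a setup.py; a sketch, with the package name and the chosen builders as placeholders:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

# Placeholder list: any concrete ExtensionBuilder subclasses to pre-compile,
# e.g. the X86CPUAdamBuilder defined earlier in this module family.
builders = [X86CPUAdamBuilder()]

setup(
    name="example-package",  # illustrative
    ext_modules=[b.builder() for b in builders],
    cmdclass={"build_ext": BuildExtension},
)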

View File

@ -1,19 +0,0 @@
from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension
from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension
from .npu_sdpa_attn_extension import NpuSdpaAttnExtension
from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension
from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad
__all__ = [
"CudaFlashAttnExtension",
"CudaMemoryEfficentAttnExtension",
"NpuSdpaAttnExtension",
"NpuTriangleAttnExtension",
"HAS_FLASH_ATTN",
"HAS_MEM_EFF_ATTN",
"HAS_NPU_TRIANGLE_ATTENTION",
"Unpad",
"AttnMaskType",
"Repad",
"SeqLenInfo",
]

View File

@ -1,100 +0,0 @@
from typing import Optional
import torch
from ..base_extension import BaseExtension
from ..utils import print_rank_0
from .utils import SeqLenInfo
def is_ampere_or_better_gpu():
# Check Ampere GPUs or newer
if torch.cuda.is_available():
device = torch.device("cuda")
properties = torch.cuda.get_device_properties(device)
if properties.major >= 8: # Ampere GPUs or newer
return True
return False
HAS_FLASH_ATTN = False
ERROR_MSG = None
if is_ampere_or_better_gpu():
try:
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
HAS_FLASH_ATTN = True
except ImportError:
ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention"
else:
ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer."
if HAS_FLASH_ATTN:
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
"""
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
if padded:
if seq_len_info_kv is None:
seq_len_info_kv = seq_len_info_q
attn_out = flash_attn_varlen_func(
q,
k,
v,
seq_len_info_q.cu_seqlens,
seq_len_info_kv.cu_seqlens,
seq_len_info_q.max_seqlen,
seq_len_info_kv.max_seqlen,
dropout_p,
scale,
causal,
)
else:
attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
return attn_out
class CudaFlashAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self):
return False
def build(self):
pass
def is_available(self):
if not HAS_FLASH_ATTN:
print_rank_0(ERROR_MSG)
return HAS_FLASH_ATTN
def load(self):
return flash_attention
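A usage sketch for the wrapper above, covering the dense causal case; the shapes follow the docstring, and an Ampere-or-newer GPU with flash-attn installed is assumed:

import torch

ext = CudaFlashAttnExtension()
if ext.is_available():
    attn_fn = ext.load()
    b, s, h, d = 2, 1024, 16, 64
    q = torch.randn(b, s, h, d, device="cuda", dtype=torch.half)
    k = torch.randn(b, s, h, d, device="cuda", dtype=torch.half)
    v = torch.randn(b, s, h, d, device="cuda", dtype=torch.half)
    # seq_len_info_q/kv are only consulted on the padded (varlen) path.
    out = attn_fn(q, k, v, None, None, causal=True)
    print(out.shape)  # (b, s, h, d)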

View File

@ -1,91 +0,0 @@
from typing import Optional
import torch
from ..base_extension import BaseExtension
from ..utils import print_rank_0
from .utils import SeqLenInfo
HAS_MEM_EFF_ATTN = False
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
HAS_MEM_EFF_ATTN = True
except ImportError:
pass
if HAS_MEM_EFF_ATTN:
"""
A general attention module using the flash attention kernels from xformers:
https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: SeqLenInfo,
seq_len_info_kv: SeqLenInfo,
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out
class CudaMemoryEfficentAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def is_available(self):
if not HAS_MEM_EFF_ATTN:
print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers")
return HAS_MEM_EFF_ATTN
def load(self):
return mem_eff_attention
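The xformers-backed path is driven the same way; a short sketch for the unpadded causal case, which maps to a LowerTriangularMask bias internally:

import torch

ext = CudaMemoryEfficentAttnExtension()
if ext.is_available():
    attn_fn = ext.load()
    b, s, h, d = 2, 512, 8, 64
    q = torch.randn(b, s, h, d, device="cuda", dtype=torch.half)
    k = torch.randn(b, s, h, d, device="cuda", dtype=torch.half)
    v = torch.randn(b, s, h, d, device="cuda", dtype=torch.half)
    out = attn_fn(q, k, v, None, None, causal=True)  # (b, s, h, d)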

View File

@ -1,60 +0,0 @@
import torch
from einops import rearrange
from ..base_extension import BaseExtension
def npu_sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
):
"""
The scaled dot product attention.
Arguments:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
scale: float. The scaling of QK^T before applying softmax.
Default to 1.
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=origin_attn_mask,
dropout_p=dropout_p,
is_causal=origin_attn_mask is None,
scale=scale,
)
output = rearrange(output, "b h s d -> b s (h d)")
return output
class NpuSdpaAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def load(self):
return npu_sdpa_attention

View File

@ -1,141 +0,0 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from ..base_extension import BaseExtension
from ..utils import print_rank_0
HAS_NPU_TRIANGLE_ATTENTION = False
try:
from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax
HAS_NPU_TRIANGLE_ATTENTION = True
except ImportError:
pass
if HAS_NPU_TRIANGLE_ATTENTION:
def npu_triangle_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q=None,
seq_len_info_kv=None,
origin_attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
scale: float = 1.0,
causal=None,
padded=None,
block_size=512,
):
"""
The triangle attention reduces the attention calculation of the mask
part by dividing the q, k, and v matrices into blocks
Arguments:
block_size: size of each inverted-triangle block (default 512); a smaller
block_size skips more of the masked computation but launches a larger
number of small operators.
masked_softmax_func: mask function to be applied.
dropout_func: dropout function to be applied.
"""
def compute_attn(q_layer, k_layer, v_layer, mask_tmp):
# [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size]
cur_sim = torch.matmul(q_layer, k_layer)
attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp)
# attention dropout
if dropout_p > 0:
attention_probs = torch.nn.functional.dropout(
attention_probs, p=dropout_p, training=attention_probs.requires_grad
)
# [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd]
context_layer_tmp = torch.matmul(attention_probs, v_layer)
return context_layer_tmp
q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)]
origin_attn_mask = origin_attn_mask.to(torch.bool)
# input shape: [b, hn, sq, hd]
bsz, head_num, sequence_len, head_dim = k.shape
sparse_groups = sequence_len // block_size
# Determine whether the sequence length is evenly divisible by the block size
divisible_flag = sequence_len == block_size * sparse_groups
k = k.transpose(2, 3).contiguous()
if divisible_flag:
q_tmp_layers = torch.chunk(q, sparse_groups, 2)
k_tmp_layers = torch.chunk(k, sparse_groups, 3)
v_tmp_layers = torch.chunk(v, sparse_groups, 2)
else:
seq_tmp = block_size * sparse_groups
q_last = q[:, :, seq_tmp:, :].contiguous()
mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous()
q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2)
k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3)
v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2)
context_list_tmp, k_tmp, v_tmp = [], (), ()
for i in range(sparse_groups):
# compute slice shape of q k v for each loop
q_begin, q_end = i * block_size, (i + 1) * block_size
kv_begin, kv_end = 0, (i + 1) * block_size
q_tmp = q_tmp_layers[i]
# slice k and v
if i == 0:
k_tmp = k_tmp_layers[i].contiguous()
v_tmp = v_tmp_layers[i].contiguous()
else:
k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous()
v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous()
mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous()
context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp)
context_list_tmp.append(context_layer_tmp)
if not divisible_flag:
# circumstances that cannot be divisible
context_layer_tmp = compute_attn(q_last, k, v, mask_last)
context_list_tmp.append(context_layer_tmp)
context_layer = torch.cat(context_list_tmp, 2)
new_context_layer_shape = (bsz, sequence_len, head_num * head_dim)
context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True)
# =========================
# Context layer. [b, sq, hp]
# =========================
return context_layer
class NpuTriangleAttnExtension(BaseExtension):
def __init__(self) -> None:
super().__init__()
@property
def requires_build(self) -> bool:
return False
def build(self):
pass
def is_available(self):
if not HAS_NPU_TRIANGLE_ATTENTION:
print_rank_0(
"ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api."
)
return HAS_NPU_TRIANGLE_ATTENTION
def load(self):
return npu_triangle_attention
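To make the block bookkeeping above concrete, a pure-arithmetic sketch; no NPU is needed and the numbers are made up:

sequence_len, block_size = 1300, 512
sparse_groups = sequence_len // block_size                   # 2 full blocks
divisible_flag = sequence_len == block_size * sparse_groups  # False
seq_tmp = block_size * sparse_groups                         # first 1024 tokens use the blocked path
tail = sequence_len - seq_tmp                                # remaining 276 tokens fall back to full-key attention
print(sparse_groups, divisible_flag, seq_tmp, tail)          # 2 False 1024 276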

View File

@ -1,91 +0,0 @@
import enum
from dataclasses import dataclass
from typing import Iterable, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
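A small sketch tying these helpers together for a padded batch; the mask and tensor shapes are illustrative:

import torch

# Two sequences with true lengths 3 and 1, padded to length 4.
attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 0, 0, 0]], device="cuda")
info = SeqLenInfo.materialize(attn_mask=attn_mask)
# info.seqlens == [3, 1], info.max_seqlen == 3, info.cu_seqlens == tensor([0, 3, 4])

x = torch.randn(2, 4, 8, 16, device="cuda")       # [b, s, nheads, head_dim]
x_unpad = Unpad.apply(x, info.indices)            # [ntokens, nheads, head_dim], ntokens == 4
x_repad = Repad.apply(x_unpad.flatten(1), info.indices, 2, 4).view(2, 4, 8, 16)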

View File

@ -0,0 +1,109 @@
import warnings
from typing import List
from .extensions import (
CpuAdamArmExtension,
CpuAdamX86Extension,
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
FusedOptimizerCudaExtension,
LayerNormCudaExtension,
MoeCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension
__all__ = [
"KernelLoader",
"CPUAdamLoader",
"LayerNormLoader",
"MoeLoader",
"FusedOptimizerLoader",
"ScaledMaskedSoftmaxLoader",
"ScaledUpperTriangleMaskedSoftmaxLoader",
]
class KernelLoader:
"""
An abstract class which offers encapsulation to the kernel loading process.
Usage:
kernel_loader = KernelLoader()
kernel = kernel_loader.load()
"""
REGISTRY: List[_Extension] = []
@classmethod
def register_extension(cls, extension: _Extension):
"""
This classmethod is an extension point which allows users to register their customized
kernel implementations to the loader.
Args:
extension (_Extension): the extension to be registered.
"""
cls.REGISTRY.append(extension)
def load(self, ext_name: str = None):
"""
Load the kernel according to the current machine.
Args:
ext_name (str): the name of the extension to be loaded. If not specified, the loader
will try to find a kernel available on the current machine.
"""
exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]
# look for exts which can be built/loaded on the current machine
if ext_name:
usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
else:
usable_exts = []
for ext in exts:
if ext.is_hardware_available():
# make sure the machine is compatible during kernel loading
ext.assert_hardware_compatible()
usable_exts.append(ext)
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
if len(usable_exts) > 1:
# if more than one usable kernel is found, we will try to load the kernel with the highest priority
usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
warnings.warn(
f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
)
return usable_exts[0].load()
class CPUAdamLoader(KernelLoader):
REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]
class LayerNormLoader(KernelLoader):
REGISTRY = [LayerNormCudaExtension]
class MoeLoader(KernelLoader):
REGISTRY = [MoeCudaExtension]
class FusedOptimizerLoader(KernelLoader):
REGISTRY = [FusedOptimizerCudaExtension]
class ScaledMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]
class FlashAttentionLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
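Typical use of these loaders; auto-selection picks the highest-priority extension whose hardware check passes, and custom kernels can be plugged in beforehand via KernelLoader.register_extension:

# Best available CPU Adam kernel for this machine (x86 vs ARM resolved by the registry).
cpu_adam = CPUAdamLoader().load()

# Same pattern for the CUDA/NPU kernels.
fused_optim = FusedOptimizerLoader().load()
flash_attn = FlashAttentionLoader().load()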

View File

@ -1 +0,0 @@
../../op_builder

View File

@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
@ -28,7 +28,7 @@ def load_fused_optim():
     global fused_optim
     if fused_optim is None:
-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()
 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):

View File

@ -11,7 +11,6 @@ from torch import Tensor
 from torch.nn.parameter import Parameter
 from colossalai.accelerator import get_accelerator
-from colossalai.kernel import LayerNorm
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.context.parallel_context import global_context as gpc
@ -23,6 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
+from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule

View File

@ -8,13 +8,12 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Parameter from torch.nn import Parameter
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.legacy.context import seed from colossalai.legacy.context import seed
from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
from colossalai.legacy.registry import LAYERS from colossalai.legacy.registry import LAYERS
from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax
@LAYERS.register_module @LAYERS.register_module

View File

@ -96,9 +96,9 @@ def _calc_l2_norm(grads):
global fused_optim global fused_optim
if fused_optim is None: if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load() fused_optim = FusedOptimizerLoader().load()
norm = 0.0 norm = 0.0
if len(grads) > 0: if len(grads) > 0:

View File

@ -11,9 +11,9 @@ MOE_KERNEL = None
def load_moe(): def load_moe():
global MOE_KERNEL global MOE_KERNEL
from colossalai.kernel.op_builder import MOEBuilder from colossalai.kernel.kernel_loader import MoeLoader
MOE_KERNEL = MOEBuilder().load() MOE_KERNEL = MoeLoader().load()
class AllGather(torch.autograd.Function): class AllGather(torch.autograd.Function):
@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function):
class HierarchicalAllToAll(torch.autograd.Function): class HierarchicalAllToAll(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor:
ctx: Any,
inputs: Tensor,
groups: Tuple[ProcessGroup, ProcessGroup],
src_rank: int
) -> Tensor:
""" """
Returns: Returns:
outputs: Tensor outputs: Tensor
@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function):
if tokens_grad.dtype != torch.float32: if tokens_grad.dtype != torch.float32:
tokens_grad = tokens_grad.to(torch.float32) tokens_grad = tokens_grad.to(torch.float32)
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, d_expert, d_logits = MOE_KERNEL.combine_backward(
mask, dest_idx) ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
)
if d_expert.dtype != ctx.dtype: if d_expert.dtype != ctx.dtype:
d_expert = d_expert.to(ctx.dtype) d_expert = d_expert.to(ctx.dtype)

View File

@ -1,75 +1,97 @@
import enum
import math import math
from collections import OrderedDict import warnings
from typing import Optional from dataclasses import dataclass
from typing import Iterable, Optional, Tuple
import torch import torch
import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader
from .base_kernel_loader import BaseKernelLoader
from .extensions.flash_attention import (
AttnMaskType,
CudaFlashAttnExtension,
CudaMemoryEfficentAttnExtension,
NpuSdpaAttnExtension,
NpuTriangleAttnExtension,
Repad,
SeqLenInfo,
Unpad,
)
from .extensions.utils import print_rank_0
class FlashAttentionLoader(BaseKernelLoader): @dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            # torch.LongTensor does not accept a device keyword; build the lengths tensor directly on the target device
            seqlens = torch.full((batch_size,), tgt_len, dtype=torch.int32, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
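A small illustration of what materialize() produces, assuming this file lands at colossalai.nn.layer.colo_attention as the import updates elsewhere in this diff suggest:

import torch

from colossalai.nn.layer.colo_attention import SeqLenInfo

# Two sequences padded to length 4; the second has only 2 real tokens.
attn_mask = torch.tensor([[1, 1, 1, 1],
                          [1, 1, 0, 0]])
info = SeqLenInfo.materialize(attn_mask=attn_mask, device=torch.device("cpu"))
# info.seqlens    -> [4, 2]
# info.max_seqlen -> 4 (as a 0-dim tensor)
# info.cu_seqlens -> tensor([0, 4, 6], dtype=torch.int32)
# info.indices    -> flattened positions of the 6 non-padded tokens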
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class Unpad(torch.autograd.Function):
""" """
FlashAttention Loader Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
options: cuda flashh attention, cuda memory effcient attention, npu sdpa attention, npu triangle attention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
batch_size: int.
seq_len: int.
dropout_p: float. Dropout probability.
sm_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
""" """
def __init__(self): @staticmethod
super().__init__( def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
# extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx ctx.save_for_backward(indices)
extension_map=OrderedDict( # [b, s, ...]
cuda_flash_attn=CudaFlashAttnExtension, assert tensor.ndim >= 3
cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, ctx.bsz = tensor.shape[0]
npu_sdpa_attn=NpuSdpaAttnExtension, out = rearrange(tensor, "b s ... -> (b s) ...")
npu_triangle_attn=NpuTriangleAttnExtension, ctx.shape = out.shape
), # [ntokens, ...]
supported_device=["cuda", "npu"], return out[indices]
)
def fetch_kernel(self, backend: str = None): @staticmethod
if backend is not None: def backward(ctx, grad_output):
if not self._extension_map[backend]().is_available(): (indices,) = ctx.saved_tensors
raise Exception(f"{backend} is not available for flash attention.") # [ntokens, ...]
return self._extension_map[backend]().fetch() grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
kernel = None
accelerator_name = get_accelerator().name class Repad(torch.autograd.Function):
assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention." """
for extension_name, extension in self._extension_map.items(): Adapted from
if extension_name.startswith(accelerator_name): https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
if extension().is_available(): """
kernel = extension().fetch()
break @staticmethod
if kernel is None: def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
raise Exception("No extension for flash attention is supported") ctx.save_for_backward(indices)
return kernel # [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
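Unpad and Repad are inverses for a fixed set of indices; a hedged round-trip sketch under the same module-path assumption as above:

import torch

from colossalai.nn.layer.colo_attention import Repad, SeqLenInfo, Unpad

batch_size, seq_len, hidden = 2, 4, 8
attn_mask = torch.tensor([[1, 1, 1, 1],
                          [1, 1, 0, 0]])
x = torch.randn(batch_size, seq_len, hidden, requires_grad=True)

info = SeqLenInfo.materialize(attn_mask=attn_mask, device=torch.device("cpu"))
packed = Unpad.apply(x, info.indices)  # [ntokens, hidden]: padding tokens dropped
repadded = Repad.apply(packed, info.indices, batch_size, seq_len)  # [b*s, hidden]: zeros re-inserted at padded slots
restored = repadded.view(batch_size, seq_len, hidden)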
class ColoAttention(torch.nn.Module): class ColoAttention(torch.nn.Module):
@ -84,7 +106,7 @@ class ColoAttention(torch.nn.Module):
self.scale = 1 / math.sqrt(embed_dim // num_heads) self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout self.dropout = dropout
self.attn = FlashAttentionLoader().fetch_kernel() self.attn = FlashAttentionLoader().load()
@staticmethod @staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
@ -120,8 +142,10 @@ class ColoAttention(torch.nn.Module):
if self.attn.__name__ == "flash_attention" and ( if self.attn.__name__ == "flash_attention" and (
query.dtype not in [torch.float16, torch.bfloat16] or bias != None query.dtype not in [torch.float16, torch.bfloat16] or bias != None
): ):
print_rank_0("flash attention is not applicable, switch to memory effcient attention") warnings.warn(
self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn") f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
)
self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1 causal = attn_mask_type is not None and attn_mask_type.value > 1

View File

@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn import init from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder from colossalai.kernel.kernel_loader import LayerNormLoader
try: try:
from colossalai._C import layer_norm from colossalai._C import layer_norm
@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
global layer_norm global layer_norm
if layer_norm is None: if layer_norm is None:
layer_norm = LayerNormBuilder().load() layer_norm = LayerNormLoader().load()
output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.layernorm_op = layer_norm ctx.layernorm_op = layer_norm
ctx.save_for_backward(input_, weight_, bias_, mean, invvar) ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

View File

@ -0,0 +1,184 @@
# This code is adapted from NVIDIA Megatron with minor changes.
import enum
import torch
import torch.nn as nn
from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
    paddedcausal = 3


# module-level cache for the lazily-built kernels; initialized to None so the
# `is None` checks in the autograd functions below can trigger loading
scaled_masked_softmax = None
scaled_upper_triang_masked_softmax = None
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
global scaled_upper_triang_masked_softmax
        if scaled_upper_triang_masked_softmax is None:
scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        # one gradient per input to forward(): (inputs, mask, scale)
        return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
Fused operation: scaling + mask + softmax
Arguments:
        input_in_fp16: Flag to indicate if input is in fp16 data format.
        input_in_bf16: Flag to indicate if input is in bf16 data format.
        attn_mask_type: Attention mask type (pad or causal).
        scaled_masked_softmax_fusion: Flag to indicate whether the fused softmax kernel should be used.
        mask_func: Mask function to be applied.
        softmax_in_fp32: If True, softmax is performed at fp32 precision.
        scale: Scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be within (16, 2048]
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
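A hedged wiring example for the fused softmax (import path taken from the other hunks in this diff; the mask function is a common Megatron-style choice and, like the availability of a CUDA device with the kernel built, an assumption here):

import torch

from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax


def attention_mask_func(attention_scores, attention_mask):
    # assumed convention: True marks positions to be masked out
    return attention_scores.masked_fill(attention_mask, -10000.0)


fused_softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

# [b, np, sq, sk] fp16 scores with a broadcastable boolean mask; the fused CUDA kernel
# is used when is_kernel_available() passes, otherwise the torch fallback path runs.
scores = torch.randn(2, 8, 32, 32, dtype=torch.float16, device="cuda")
mask = torch.zeros(2, 1, 32, 32, dtype=torch.bool, device="cuda")
probs = fused_softmax(scores, mask)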

View File

@ -3,7 +3,7 @@ from typing import Optional
import torch import torch
from colossalai.kernel import CPUAdamLoader from colossalai.kernel.kernel_loader import CPUAdamLoader
from .nvme_optimizer import NVMeOptimizer from .nvme_optimizer import NVMeOptimizer

View File

@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0 self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none self.set_grad_none = set_grad_none
if multi_tensor_applier.available: if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load() fused_optim = FusedOptimizerLoader().load()
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.cuda.IntTensor([0])

View File

@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer):
) )
super(FusedLAMB, self).__init__(params, defaults) super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available: if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load() fused_optim = FusedOptimizerLoader().load()
self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
# Skip buffer # Skip buffer

View File

@ -72,9 +72,9 @@ class FusedSGD(Optimizer):
self.wd_after_momentum = wd_after_momentum self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available: if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load() fused_optim = FusedOptimizerLoader().load()
# Skip buffer # Skip buffer
self._dummy_overflow_buf = torch.tensor( self._dummy_overflow_buf = torch.tensor(

View File

@ -2,7 +2,7 @@ from typing import Any, Optional
import torch import torch
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam from .cpu_adam import CPUAdam
@ -85,7 +85,7 @@ class HybridAdam(CPUAdam):
nvme_offload_dir, nvme_offload_dir,
) )
if torch.cuda.is_available(): if torch.cuda.is_available():
fused_optim = FusedOptimBuilder().load() fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self._dummy_overflow_buf = torch.cuda.IntTensor([0])

View File

@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule from .base import PipelineSchedule

View File

@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
from ._utils import ( from ._utils import (
detach, detach,

View File

@ -62,7 +62,7 @@ def forward_fn():
def get_blip2_flash_attention_forward(): def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
from colossalai.kernel import ColoAttention from colossalai.nn.layer.colo_attention import ColoAttention
def forward( def forward(
self: Blip2Attention, self: Blip2Attention,

View File

@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
def get_flash_core_attention_forward(): def get_flash_core_attention_forward():
from colossalai.kernel import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention from .chatglm2_6b.modeling_chatglm import CoreAttention

View File

@ -719,7 +719,7 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward(): def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from colossalai.kernel import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size): def split_heads(tensor, num_heads, attn_head_size):
""" """

View File

@ -530,7 +530,7 @@ class GPTJPipelineForwards:
def get_gptj_flash_attention_forward(): def get_gptj_flash_attention_forward():
from transformers.models.gptj.modeling_gptj import GPTJAttention from transformers.models.gptj.modeling_gptj import GPTJAttention
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_attention_heads, attn_head_size, rotary): def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
""" """

View File

@ -1,5 +1,5 @@
import warnings import warnings
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -420,7 +420,7 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward(shard_config: ShardConfig): def get_llama_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
llama_version = 2 llama_version = 2
try: try:

View File

@ -6,7 +6,7 @@ import torch
def get_mistral_flash_attention_forward(): def get_mistral_flash_attention_forward():
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def forward( def forward(
self: MistralAttention, self: MistralAttention,

View File

@ -514,7 +514,7 @@ class OPTPipelineForwards:
def get_opt_flash_attention_forward(): def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention from transformers.models.opt.modeling_opt import OPTAttention
from colossalai.kernel import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def forward( def forward(
self: OPTAttention, self: OPTAttention,

View File

@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward(): def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention from transformers.models.vit.modeling_vit import ViTSelfAttention
from colossalai.kernel import ColoAttention from colossalai.nn.layer.colo_attention import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)

View File

@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward(): def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention from transformers.models.whisper.modeling_whisper import WhisperAttention
from colossalai.kernel import AttnMaskType, ColoAttention from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()

View File

@ -4,6 +4,7 @@ from .common import (
disposable, disposable,
ensure_path_exists, ensure_path_exists,
free_storage, free_storage,
get_current_device,
is_ddp_ignored, is_ddp_ignored,
set_seed, set_seed,
) )
@ -22,5 +23,6 @@ __all__ = [
"_cast_float", "_cast_float",
"free_storage", "free_storage",
"set_seed", "set_seed",
"get_current_device",
"is_ddp_ignored", "is_ddp_ignored",
] ]

View File

@ -10,6 +10,15 @@ from typing import Callable
import numpy as np import numpy as np
import torch import torch
from colossalai.accelerator import get_accelerator
def get_current_device():
"""
    A thin wrapper around the accelerator's API, kept for backward compatibility; returns the accelerator's current device.
"""
return get_accelerator().get_current_device()
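A trivial usage sketch for code that still takes its device from colossalai.utils rather than from the accelerator API directly:

import torch

from colossalai.utils import get_current_device

# resolves to the active accelerator's device, e.g. cuda:<rank> on a GPU machine
x = torch.empty(16, device=get_current_device())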
def ensure_path_exists(filename: str): def ensure_path_exists(filename: str):
# ensure the path exists # ensure the path exists

View File

@ -190,8 +190,10 @@ class Chunk:
def device_type(self) -> str: def device_type(self) -> str:
if self.chunk_temp is not None: if self.chunk_temp is not None:
return self.chunk_temp.device.type return self.chunk_temp.device.type
else: elif self.is_gathered or self.cuda_shard is not None:
return get_accelerator().name return get_accelerator().name
else:
return "cpu"
@property @property
def payload(self) -> torch.Tensor: def payload(self) -> torch.Tensor:

View File

@ -3,13 +3,13 @@ import inspect
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.kernel import LayerNorm
from colossalai.legacy.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.legacy.pipeline.utils import partition_uniform
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding
from .layers.init_method import init_normal, output_init_normal from .layers.init_method import init_normal, output_init_normal

View File

@ -3,9 +3,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from loss_func.cross_entropy import vocab_cross_entropy from loss_func.cross_entropy import vocab_cross_entropy
from colossalai.kernel import LayerNorm
from colossalai.legacy.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
from .linear import Linear from .linear import Linear
from .pooler import Pooler from .pooler import Pooler

View File

@ -8,12 +8,12 @@ from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert from model.bert import BertForPretrain, build_pipeline_bert
import colossalai import colossalai
from colossalai.kernel import LayerNorm
from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import is_using_pp from colossalai.legacy.utils import is_using_pp
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
from colossalai.nn.optimizer import FusedAdam from colossalai.nn.optimizer import FusedAdam
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer

extensions/__init__.py
View File

@ -0,0 +1,36 @@
from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
from .flash_attention import (
FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
)
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension
ALL_EXTENSIONS = [
CpuAdamArmExtension,
CpuAdamX86Extension,
LayerNormCudaExtension,
MoeCudaExtension,
FusedOptimizerCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
FlashAttentionXformersCudaExtension,
FlashAttentionNpuExtension,
]
__all__ = [
"CpuAdamArmExtension",
"CpuAdamX86Extension",
"LayerNormCudaExtension",
"MoeCudaExtension",
"FusedOptimizerCudaExtension",
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",
"FlashAttentionXformersCudaExtension",
"FlashAttentionNpuExtension",
]
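A hedged sketch of how this registry can be used to probe the current machine, assuming the top-level extensions package is importable:

from extensions import ALL_EXTENSIONS

# report which registered kernels could be built or loaded on this machine
for ext_cls in ALL_EXTENSIONS:
    ext = ext_cls()
    print(f"{ext.name}: available={ext.is_hardware_available()}")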

View File

@ -0,0 +1,82 @@
import hashlib
import os
from abc import ABC, abstractmethod
from typing import Union
__all__ = ["_Extension"]
class _Extension(ABC):
def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1):
self._name = name
self._support_aot = support_aot
self._support_jit = support_jit
self.priority = priority
@property
def name(self):
return self._name
@property
def support_aot(self):
return self._support_aot
@property
def support_jit(self):
return self._support_jit
@staticmethod
def get_jit_extension_folder_path():
"""
        Kernels compiled at runtime are stored in a shared cache folder for reuse.
        The folder is located at ~/.cache/colossalai/torch_extensions/<cache-folder>.
        The name of the <cache-folder> follows the format:
        torch<torch_version_major>.<torch_version_minor>_<device_name>-<device_version>-<hash>
        where <hash> is the SHA-256 hash of the path of the installed `colossalai` package (colossalai.__file__).
"""
import torch
import colossalai
from colossalai.accelerator import get_accelerator
# get torch version
torch_version_major = torch.__version__.split(".")[0]
torch_version_minor = torch.__version__.split(".")[1]
# get device version
device_name = get_accelerator().name
device_version = get_accelerator().get_version()
# use colossalai's file path as hash
hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest()
# concat
home_directory = os.path.expanduser("~")
extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}"
cache_directory = os.path.join(home_directory, extension_directory)
return cache_directory
@abstractmethod
def is_hardware_available(self) -> bool:
"""
Check if the hardware required by the kernel is available.
"""
@abstractmethod
    def assert_hardware_compatible(self) -> None:
        """
        Assert that the hardware required by the kernel is compatible with the current machine; raise otherwise.
"""
@abstractmethod
def build_aot(self) -> Union["CppExtension", "CUDAExtension"]:
pass
@abstractmethod
def build_jit(self) -> None:
pass
@abstractmethod
def load(self):
pass

extensions/cpp_extension.py
View File

@ -0,0 +1,134 @@
import importlib
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List
from .base_extension import _Extension
__all__ = ["_CppExtension"]
class _CppExtension(_Extension):
def __init__(self, name: str, priority: int = 1):
super().__init__(name, support_aot=True, support_jit=True, priority=priority)
# we store the op as an attribute to avoid repeated building and loading
self.cached_op = None
# build-related variables
self.prebuilt_module_path = "colossalai._C"
self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
def csrc_abs_path(self, path):
return os.path.join(self.relative_to_abs_path("csrc"), path)
def relative_to_abs_path(self, code_path: str) -> str:
"""
        This function takes a path relative to the `extensions` package directory and returns the corresponding absolute path.
"""
        # start from this file and walk up the directory tree until we reach the
        # directory named "extensions" (the root of the extension package), then
        # resolve the given relative path against it
current_file_path = Path(__file__)
while True:
if current_file_path.name == "extensions":
break
else:
current_file_path = current_file_path.parent
extension_module_path = current_file_path
code_abs_path = extension_module_path.joinpath(code_path)
return str(code_abs_path)
    # functions must be overridden over
def strip_empty_entries(self, args):
"""
Drop any empty strings from the list of compile and link flags
"""
return [x for x in args if len(x) > 0]
def import_op(self):
"""
This function will import the op module by its string name.
"""
return importlib.import_module(self.prebuilt_import_path)
def build_aot(self) -> "CppExtension":
from torch.utils.cpp_extension import CppExtension
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
def build_jit(self) -> None:
from torch.utils.cpp_extension import load
build_directory = _Extension.get_jit_extension_folder_path()
build_directory = Path(build_directory)
build_directory.mkdir(parents=True, exist_ok=True)
# check if the kernel has been built
compiled_before = False
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
if kernel_file_path.exists():
compiled_before = True
# load the kernel
if compiled_before:
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
else:
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
build_start = time.time()
op_kernel = load(
name=self.name,
sources=self.strip_empty_entries(self.sources_files()),
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
extra_cflags=self.cxx_flags(),
extra_ldflags=[],
build_directory=str(build_directory),
)
build_duration = time.time() - build_start
if compiled_before:
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
else:
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
return op_kernel
    # functions must be overridden begin
@abstractmethod
def sources_files(self) -> List[str]:
"""
This function should return a list of source files for extensions.
"""
@abstractmethod
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
@abstractmethod
def cxx_flags(self) -> List[str]:
"""
This function should return a list of cxx compilation flags for extensions.
"""
def load(self):
try:
op_kernel = self.import_op()
except ImportError:
# if import error occurs, it means that the kernel is not pre-built
# so we build it jit
op_kernel = self.build_jit()
return op_kernel
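A hedged sketch of a concrete C++ extension built on this class; the kernel name and source file are hypothetical, and only the hooks marked as required above plus the two hardware checks are overridden:

from extensions.cpp_extension import _CppExtension


class MyCppExtension(_CppExtension):
    def __init__(self):
        super().__init__(name="my_cpp_kernel")

    def is_hardware_available(self) -> bool:
        return True  # plain C++: any CPU will do

    def assert_hardware_compatible(self) -> None:
        pass

    def sources_files(self):
        return [self.csrc_abs_path("my_kernel.cpp")]  # hypothetical file under extensions/csrc/

    def include_dirs(self):
        return []

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros


# kernel = MyCppExtension().load()  # imports colossalai._C.my_cpp_kernel if prebuilt, otherwise JIT-compiles it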

View File

@ -0,0 +1,5 @@
from .cpu_adam_arm import CpuAdamArmExtension
from .cpu_adam_x86 import CpuAdamX86Extension
__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension']

View File

@ -0,0 +1,41 @@
import platform
from ..cpp_extension import _CppExtension
class CpuAdamArmExtension(_CppExtension):
def __init__(self):
super().__init__(name="cpu_adam_arm")
def is_hardware_available(self) -> bool:
# only arm allowed
return platform.machine() == "aarch64"
def assert_hardware_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "aarch64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("arm/cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return []
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []

View File

@ -1,19 +1,27 @@
from .builder import Builder import platform
from .utils import append_nvcc_threads
from ..cuda_extension import _CudaExtension
from ..utils import append_nvcc_threads
class CPUAdamBuilder(Builder): class CpuAdamX86Extension(_CudaExtension):
NAME = "cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"
def __init__(self): def __init__(self):
super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) super().__init__(name="cpu_adam_x86")
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
def is_hardware_available(self) -> bool:
return platform.machine() == "x86_64" and super().is_hardware_available()
def assert_hardware_compatible(self) -> None:
arch = platform.machine()
assert (
arch == "x86_64"
), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
super().assert_hardware_compatible()
# necessary 4 functions # necessary 4 functions
def sources_files(self): def sources_files(self):
ret = [ ret = [
self.csrc_abs_path("cpu_adam.cpp"), self.csrc_abs_path("cuda/cpu_adam.cpp"),
] ]
return ret return ret

Some files were not shown because too many files have changed in this diff.