[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)
* refactor compilation mechanism and unify multi-hardware support
* fix file path bug
* add __init__.py to make pybind a proper module and avoid the relative-path error caused by the softlink
* delete duplicated macros
* fix macro bug in gcc
extensions/pybind/softmax/__init__.py (new file, +4)
from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension
from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension

__all__ = ["ScaledMaskedSoftmaxCudaExtension", "ScaledUpperTriangleMaskedSoftmaxCudaExtension"]
extensions/pybind/softmax/scaled_masked_softmax.cpp (new file, +54)
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
                        int attn_heads);

torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
                  float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward", &bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");

  m.def("get_batch_per_block", &get_batch_per_block,
        "Return Batch per block size.");
}
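For reference, the binding above computes a masked, scaled softmax over 4D attention scores. Below is a minimal pure-PyTorch sketch of the same computation; the [batch, heads, q_len, k_len] layout and the "mask == 1 means masked out" convention follow the Megatron kernel and are assumptions for illustration, not part of this diff (the CUDA binding itself accepts only fp16/bf16 tensors on GPU).

import torch

def scaled_masked_softmax_ref(x: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # Positions where mask == 1 are excluded from the softmax (assumed Megatron convention).
    scores = (x * scale).masked_fill(mask.bool(), float("-inf"))
    return torch.softmax(scores, dim=-1)

# Illustrative shapes: [batch, heads, q_len, k_len] scores, broadcastable [batch, 1, q_len, k_len] mask.
x = torch.randn(2, 8, 128, 128)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.uint8)
probs = scaled_masked_softmax_ref(x, mask, 1.0)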
extensions/pybind/softmax/scaled_masked_softmax_cuda.py (new file, +28)
from ...cuda_extension import _CudaExtension
from ...utils import append_nvcc_threads


class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="scaled_masked_softmax_cuda")

    def sources_files(self):
        ret = [self.csrc_abs_path(fname) for fname in ["kernel/cuda/scaled_masked_softmax_kernel.cu"]] + [
            self.pybind_abs_path("softmax/scaled_masked_softmax.cpp")
        ]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-std=c++17",  # a single -std flag; nvcc rejects duplicate -std arguments
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-DTHRUST_IGNORE_CUB_VERSION_CHECK",
        ]
        ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + super().nvcc_flags()
        return append_nvcc_threads(ret)
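A hedged sketch of how this extension metadata could drive a just-in-time build by hand, assuming a working CUDA toolchain. ColossalAI routes builds through its own loader machinery, so the direct torch.utils.cpp_extension.load call below is illustrative only.

from torch.utils.cpp_extension import load

ext_spec = ScaledMaskedSoftmaxCudaExtension()
module = load(
    name="scaled_masked_softmax_cuda",
    sources=ext_spec.sources_files(),        # CUDA kernel + the pybind wrapper above
    extra_cflags=ext_spec.cxx_flags(),
    extra_cuda_cflags=ext_spec.nvcc_flags(),
    verbose=True,
)
# module.forward(input, mask, scale) then dispatches into fwd_cuda.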
extensions/pybind/softmax/scaled_upper_triang_masked_softmax.cpp (new file, +44)
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward", &bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
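Unlike the masked variant, this upper-triangular version takes no mask argument and applies a causal mask internally. A pure-PyTorch reference sketch follows; the 3D layout [batch*heads, seq_q, seq_k] is an interpretation of the assertions above based on the Megatron kernel, not something stated in this diff.

import torch

def scaled_upper_triang_masked_softmax_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Mask out key positions j > i so each query attends only to itself and earlier keys.
    seq_q, seq_k = x.shape[-2], x.shape[-1]
    causal = torch.triu(torch.ones(seq_q, seq_k, device=x.device), diagonal=1).bool()
    scores = (x * scale).masked_fill(causal, float("-inf"))
    return torch.softmax(scores, dim=-1)

x = torch.randn(16, 128, 128)  # e.g. batch=2 * heads=8, flattened to 3D as the asserts expect
probs = scaled_upper_triang_masked_softmax_ref(x, 1.0)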
extensions/pybind/softmax/scaled_upper_triangle_masked_softmax_cuda.py (new file, +30)
from ...cuda_extension import _CudaExtension
from ...utils import append_nvcc_threads, get_cuda_cc_flag


class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="scaled_upper_triangle_masked_softmax_cuda")

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu",
            ]
        ] + [self.pybind_abs_path("softmax/scaled_upper_triang_masked_softmax.cpp")]
        return ret

    def cxx_flags(self):
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = [
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()
        return append_nvcc_threads(ret)
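As a usage note, both extensions could also be compiled ahead of time with a conventional setup.py instead of the JIT path. The sketch below calls torch.utils.cpp_extension directly; the source paths are illustrative placeholders (the real paths come from csrc_abs_path/pybind_abs_path), and this is not the project's actual build script.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="colossal_softmax_kernels",
    ext_modules=[
        CUDAExtension(
            name="scaled_masked_softmax_cuda",
            sources=[
                "extensions/csrc/kernel/cuda/scaled_masked_softmax_kernel.cu",
                "extensions/pybind/softmax/scaled_masked_softmax.cpp",
            ],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
        ),
        CUDAExtension(
            name="scaled_upper_triangle_masked_softmax_cuda",
            sources=[
                "extensions/csrc/kernel/cuda/scaled_upper_triang_masked_softmax_kernel.cu",
                "extensions/pybind/softmax/scaled_upper_triang_masked_softmax.cpp",
            ],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)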