Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[setup] support pre-build and jit-build of cuda kernels (#2374)
* [setup] support pre-build and jit-build of cuda kernels
* polish code
op_builder/README.md (new file, 31 lines)
@@ -0,0 +1,31 @@
# Build PyTorch Extensions

## Overview

Building PyTorch extensions can be a difficult task for users without a systems background, and it is frustrating to run into obscure technical jargon while installing Colossal-AI. Therefore, we provide two methods of building the PyTorch extensions:

1. Build the CUDA extensions during `pip install` if `CUDA_EXT=1` is set
2. Build an extension at runtime

The first method is more suitable for users who are familiar with CUDA environment configuration. The second method suits those who are not, as it only builds the kernels actually required by their program.

These two methods have different advantages and disadvantages.
Method 1 builds all kernels during installation, so the kernels can be imported directly and the user does not need to care about kernel building while running their program. However, installation may fail if the user does not know how to configure the environment, which leads to much frustration.
Method 2 builds only the kernels the user actually needs, so environment issues are less likely. However, the first build and the subsequent load may slow the program down.
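For example, a user who skipped pre-building can JIT-build a single kernel the first time it is needed (a minimal usage sketch; the import path assumes the `op_builder` package, which is symlinked into `colossalai/kernel/op_builder` on installation):

```python
# Runtime (JIT) build: only the requested kernel is compiled, on first use.
# The pre-build alternative is `CUDA_EXT=1 pip install .`, which compiles all kernels at install time.
from op_builder import CPUAdamBuilder

cpu_adam = CPUAdamBuilder().load()  # falls back to a runtime build if no pre-built kernel is found
```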
## PyTorch Extensions in Colossal-AI

As mentioned in the section above, our aim is to support these two methods coherently in Colossal-AI, meaning that a kernel can be built either in `setup.py` or at runtime.
There are mainly two functions used to build extensions:

1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`.
2. `torch.utils.cpp_extension.load`: used to build and load an extension at runtime.

Please note that an extension built by `CUDAExtension` cannot be loaded by the `load` function; `load` will run its own build again (correct me if I am wrong).
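Concretely, the two entry points look roughly like this (a minimal sketch of the two `torch.utils.cpp_extension` calls; the `cpu_adam` source path is used only as an illustration):

```python
from torch.utils.cpp_extension import CUDAExtension, load

# 1. Ahead-of-time build, declared in setup.py and compiled during `pip install`:
ext_module = CUDAExtension(name='colossalai._C.cpu_adam',
                           sources=['colossalai/kernel/cuda_native/csrc/cpu_adam.cpp'])

# 2. Runtime (JIT) build: compiles the sources and loads the resulting module immediately.
op_module = load(name='cpu_adam',
                 sources=['colossalai/kernel/cuda_native/csrc/cpu_adam.cpp'],
                 verbose=True)
```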
We have implemented the following conventions:

1. All pre-built kernels (those installed with `setup.py`) can be found under `colossalai._C`.
2. All runtime-built kernels are placed in the default torch extension path, i.e. `~/.cache/colossalai/torch_extensions`. (If we put the built kernels in the installed site-packages directory, `pip uninstall` would be incomplete.)

When loading a built kernel, we first check whether the pre-built one exists. If not, the runtime build is triggered.
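For example, after this change a kernel such as `cpu_adam` can be reached in two ways, depending on how it was built (illustrative; the exact cache folder name depends on your torch and CUDA versions):

```python
# 1. Pre-built during `pip install` (CUDA_EXT=1): importable from the colossalai._C package.
from colossalai._C import cpu_adam

# 2. Built at runtime: cached under the user's home directory, e.g. for torch 1.13 with CUDA 11.7
#    it lives in ~/.cache/colossalai/torch_extensions/torch1.13_cu11.7/
#    and is returned directly by the builder:
from op_builder import CPUAdamBuilder
cpu_adam = CPUAdamBuilder().load()
```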
@@ -1,7 +1,23 @@
from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder
from .layernorm import LayerNormBuilder
from .moe import MOEBuilder
from .multi_head_attn import MultiHeadAttnBuilder
from .scaled_upper_triang_masked_softmax import ScaledSoftmaxBuilder
from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder

__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledSoftmaxBuilder', 'MOEBuilder']
ALL_OPS = {
    'cpu_adam': CPUAdamBuilder,
    'fused_optim': FusedOptimBuilder,
    'moe': MOEBuilder,
    'multi_head_attn': MultiHeadAttnBuilder,
    'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder,
    'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder,
    'layernorm': LayerNormBuilder,
}

__all__ = [
    'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder',
    'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder',
    'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder'
]
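For setup-time pre-building, the new `ALL_OPS` registry lets a build script iterate over every kernel. Below is a sketch of how a `setup.py` might consume it when `CUDA_EXT=1` is set (hypothetical excerpt, not the actual `setup.py` in this PR):

```python
# Hypothetical setup.py excerpt: pre-build every registered kernel when CUDA_EXT=1.
import os

from op_builder import ALL_OPS

ext_modules = []
if os.environ.get('CUDA_EXT', '0') == '1':
    for op_name, builder_cls in ALL_OPS.items():
        # Builder.builder() returns a torch.utils.cpp_extension.CUDAExtension whose
        # name is the prebuilt import path, e.g. colossalai._C.cpu_adam.
        ext_modules.append(builder_cls().builder())

# setup(..., ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension})
```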
@@ -1,40 +1,49 @@
import importlib
import os
import re
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List


def get_cuda_cc_flag() -> List:
    """get_cuda_cc_flag
class Builder(ABC):
    """
    Builder is the base class to build extensions for PyTorch.

    cc flag for your GPU arch
    Args:
        name (str): the name of the kernel to be built
        prebuilt_import_path (str): the path where the extension is installed during pip install
    """

    # only import torch when needed
    # this is to avoid importing torch when building on a machine without torch pre-installed
    # one case is to build wheel for pypi release
    import torch

    cc_flag = []
    for arch in torch.cuda.get_arch_list():
        res = re.search(r'sm_(\d+)', arch)
        if res:
            arch_cap = res[1]
            if int(arch_cap) >= 60:
                cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
    def __init__(self, name: str, prebuilt_import_path: str):
        self.name = name
        self.prebuilt_import_path = prebuilt_import_path
        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

    return cc_flag
        assert prebuilt_import_path.startswith('colossalai._C'), \
            f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'

    def relative_to_abs_path(self, code_path: str) -> str:
        """
        This function takes in a path relative to the colossalai root directory and return the absolute path.
        """
        op_builder_module_path = Path(__file__).parent

class Builder(object):

    def colossalai_src_path(self, code_path):
        current_file_path = Path(__file__)
        if os.path.islink(current_file_path.parent):
            # symbolic link
            return os.path.join(current_file_path.parent.parent.absolute(), code_path)
        # if we install from source
        # the current file path will be op_builder/builder.py
        # if we install via pip install colossalai
        # the current file path will be colossalai/kernel/op_builder/builder.py
        # this is because that the op_builder inside colossalai is a symlink
        # this symlink will be replaced with actual files if we install via pypi
        # thus we cannot tell the colossalai root directory by checking whether the op_builder
        # is a symlink, we can only tell whether it is inside or outside colossalai
        if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'):
            root_path = op_builder_module_path.parent.parent
        else:
            return os.path.join(current_file_path.parent.parent.absolute(), "colossalai", "kernel", code_path)
            root_path = op_builder_module_path.parent.joinpath('colossalai')

        code_abs_path = root_path.joinpath(code_path)
        return str(code_abs_path)

    def get_cuda_home_include(self):
        """
@@ -46,47 +55,94 @@ class Builder(object):
        cuda_include = os.path.join(CUDA_HOME, "include")
        return cuda_include

    def csrc_abs_path(self, path):
        return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path)

    # functions must be overrided begin
    def sources_files(self):
    @abstractmethod
    def sources_files(self) -> List[str]:
        """
        This function should return a list of source files for extensions.
        """
        raise NotImplementedError

    def include_dirs(self):
        raise NotImplementedError
    @abstractmethod
    def include_dirs(self) -> List[str]:
        """
        This function should return a list of inlcude files for extensions.
        """
        pass

    def cxx_flags(self):
        raise NotImplementedError
    @abstractmethod
    def cxx_flags(self) -> List[str]:
        """
        This function should return a list of cxx compilation flags for extensions.
        """
        pass

    def nvcc_flags(self):
        raise NotImplementedError
    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        """
        This function should return a list of nvcc compilation flags for extensions.
        """
        pass

    # functions must be overrided over

    def strip_empty_entries(self, args):
        '''
        Drop any empty strings from the list of compile and link flags
        '''
        return [x for x in args if len(x) > 0]

    def import_op(self):
        """
        This function will import the op module by its string name.
        """
        return importlib.import_module(self.prebuilt_import_path)

    def load(self, verbose=True):
        """
        load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
        If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
        kernel is built during pip install, it can be accessed through `colossalai._C`.

        load and compile cpu_adam lib at runtime
        Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.

        Args:
            verbose (bool, optional): show detailed info. Defaults to True.
        """
        import time

        from torch.utils.cpp_extension import load
        start_build = time.time()

        op_module = load(name=self.name,
                         sources=self.strip_empty_entries(self.sources_files()),
                         extra_include_paths=self.strip_empty_entries(self.include_dirs()),
                         extra_cflags=self.cxx_flags(),
                         extra_cuda_cflags=self.nvcc_flags(),
                         extra_ldflags=[],
                         verbose=verbose)
        try:
            op_module = self.import_op()
            if verbose:
                print(f"OP {self.prebuilt_import_path} already exists, skip building.")
        except ImportError:
            # construct the build directory
            import torch
            torch_version_major = torch.__version__.split('.')[0]
            torch_version_minor = torch.__version__.split('.')[1]
            torch_cuda_version = torch.version.cuda
            home_directory = os.path.expanduser('~')
            extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
            build_directory = os.path.join(home_directory, extension_directory)
            Path(build_directory).mkdir(parents=True, exist_ok=True)

            if verbose:
                print("=========================================================================================")
                print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
                print("=========================================================================================")

            # load the kernel
            op_module = load(name=self.name,
                             sources=self.strip_empty_entries(self.sources_files()),
                             extra_include_paths=self.strip_empty_entries(self.include_dirs()),
                             extra_cflags=self.cxx_flags(),
                             extra_cuda_cflags=self.nvcc_flags(),
                             extra_ldflags=[],
                             build_directory=build_directory,
                             verbose=verbose)

        build_duration = time.time() - start_build
        if verbose:
@@ -94,17 +150,16 @@ class Builder(object):

        return op_module

    def builder(self, name) -> 'CUDAExtension':
    def builder(self) -> 'CUDAExtension':
        """
        get a CUDAExtension instance used for setup.py
        """
        from torch.utils.cpp_extension import CUDAExtension

        return CUDAExtension(
            name=name,
            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()],
            include_dirs=self.include_dirs(),
            extra_compile_args={
                'cxx': self.cxx_flags(),
                'nvcc': self.nvcc_flags()
            })
        return CUDAExtension(name=self.prebuilt_import_path,
                             sources=self.strip_empty_entries(self.sources_files()),
                             include_dirs=self.strip_empty_entries(self.include_dirs()),
                             extra_compile_args={
                                 'cxx': self.strip_empty_entries(self.cxx_flags()),
                                 'nvcc': self.strip_empty_entries(self.nvcc_flags())
                             })
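With the refactored base class above, adding a new kernel amounts to overriding the four abstract methods. A hypothetical subclass as a sketch (the kernel name and its source files are illustrative only and not part of this PR):

```python
# Illustrative only: 'my_kernel' and its source files do not exist in the repository.
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class MyKernelBuilder(Builder):
    NAME = "my_kernel"
    PREBUILT_IMPORT_PATH = "colossalai._C.my_kernel"

    def __init__(self):
        super().__init__(name=MyKernelBuilder.NAME, prebuilt_import_path=MyKernelBuilder.PREBUILT_IMPORT_PATH)

    def sources_files(self):
        return [self.csrc_abs_path(fname) for fname in ['my_kernel.cpp', 'my_kernel_cuda.cu']]

    def include_dirs(self):
        return [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()]

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = get_cuda_cc_flag()
        return append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags)
```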
@@ -6,24 +6,22 @@ from .utils import append_nvcc_threads

class CPUAdamBuilder(Builder):
    NAME = "cpu_adam"
    BASE_DIR = "cuda_native"
    PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam"

    def __init__(self):
        self.name = CPUAdamBuilder.NAME
        super().__init__()

        super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH)
        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

    # necessary 4 functions
    def sources_files(self):
        ret = [
            os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"),
            self.csrc_abs_path('cpu_adam.cpp'),
        ]
        return [self.colossalai_src_path(path) for path in ret]
        return ret

    def include_dirs(self):
        return [
            self.colossalai_src_path(os.path.join(CPUAdamBuilder.BASE_DIR, "includes")),
            self.csrc_abs_path("includes"),
            self.get_cuda_home_include()
        ]

@@ -36,7 +34,5 @@ class CPUAdamBuilder(Builder):
            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
        ]

        return append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags)

    # necessary 4 functions
        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
@@ -1,20 +1,19 @@
import os

from .builder import Builder, get_cuda_cc_flag
from .builder import Builder
from .utils import get_cuda_cc_flag


class FusedOptimBuilder(Builder):
    NAME = 'fused_optim'
    BASE_DIR = "cuda_native/csrc"
    NAME = "fused_optim"
    PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim"

    def __init__(self):
        self.name = FusedOptimBuilder.NAME
        super().__init__()
        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

        super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH)

    def sources_files(self):
        ret = [
            self.colossalai_src_path(os.path.join(FusedOptimBuilder.BASE_DIR, fname)) for fname in [
            self.csrc_abs_path(fname) for fname in [
                'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
                'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
            ]
@@ -22,12 +21,12 @@ class FusedOptimBuilder(Builder):
        return ret

    def include_dirs(self):
        ret = [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_home_include()]
        return [self.colossalai_src_path(path) for path in ret]
        ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()]
        return ret

    def cxx_flags(self):
        extra_cxx_flags = []
        return ['-O3'] + self.version_dependent_macros + extra_cxx_flags
        version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
        return ['-O3'] + version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ['-lineinfo']
op_builder/layernorm.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import os

from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class LayerNormBuilder(Builder):
    NAME = "layernorm"
    PREBUILT_IMPORT_PATH = "colossalai._C.layernorm"

    def __init__(self):
        super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH)

    def sources_files(self):
        ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']]
        return ret

    def include_dirs(self):
        ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()]
        return ret

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ['-maxrregcount=50']
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros
        return append_nvcc_threads(ret)
@@ -1,27 +1,30 @@
import os

from .builder import Builder, get_cuda_cc_flag
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class MOEBuilder(Builder):

    NAME = "moe"
    PREBUILT_IMPORT_PATH = "colossalai._C.moe"

    def __init__(self):
        self.base_dir = "cuda_native/csrc"
        self.name = 'moe'
        super().__init__()
        super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = []
        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
        ret.append(os.path.join(self.base_dir, "kernels", "include"))
        return [self.colossalai_src_path(path) for path in ret]
        ret = [
            self.csrc_abs_path("kernels/include"),
            self.get_cuda_home_include()
        ]
        return ret

    def sources_files(self):
        ret = [os.path.join(self.base_dir, fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']]
        return [self.colossalai_src_path(path) for path in ret]
        ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']]
        return ret

    def cxx_flags(self):
        return ['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = [
@@ -30,4 +33,4 @@ class MOEBuilder(Builder):
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
        return ret
        return append_nvcc_threads(ret)
@@ -1,32 +1,32 @@
import os

from .builder import Builder, get_cuda_cc_flag
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class MultiHeadAttnBuilder(Builder):

    def __init__(self):
        self.base_dir = "cuda_native/csrc"
        self.name = 'multihead_attention'
        super().__init__()
    NAME = "multihead_attention"
    PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention"

        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
    def __init__(self):
        super().__init__(name=MultiHeadAttnBuilder.NAME,
                         prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        ret = []
        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
        ret.append(os.path.join(self.base_dir, "kernels", "include"))
        return [self.colossalai_src_path(path) for path in ret]
        ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]
        return ret

    def sources_files(self):
        ret = [
            os.path.join(self.base_dir, fname) for fname in [
            self.csrc_abs_path(fname) for fname in [
                'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
                'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
                'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
            ]
        ]
        return [self.colossalai_src_path(path) for path in ret]
        return ret

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros
@@ -37,5 +37,5 @@ class MultiHeadAttnBuilder(Builder):
            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
        return ret
        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
op_builder/scaled_masked_softmax.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import os

from .builder import Builder
from .utils import append_nvcc_threads


class ScaledMaskedSoftmaxBuilder(Builder):
    NAME = "scaled_masked_softmax"
    PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax"

    def __init__(self):
        super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH)

    # necessary 4 functions
    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname) for fname in
            ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu']
        ]
        return ret

    def include_dirs(self):
        return [
            self.csrc_abs_path("kernels/include"),
            self.get_cuda_home_include()
        ]

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = [
            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
        ]
        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
        return append_nvcc_threads(ret)
@@ -1,36 +0,0 @@
import os

from .builder import Builder, get_cuda_cc_flag


class ScaledSoftmaxBuilder(Builder):

    def __init__(self):
        self.base_dir = "cuda_native/csrc"
        self.name = 'scaled_upper_triang_masked_softmax'
        super().__init__()

    def include_dirs(self):
        ret = []
        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
        ret.append(os.path.join(self.base_dir, "kernels", "include"))
        return [self.colossalai_src_path(path) for path in ret]

    def sources_files(self):
        ret = [
            os.path.join(self.base_dir, fname)
            for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu']
        ]
        return [self.colossalai_src_path(path) for path in ret]

    def cxx_flags(self):
        return ['-O3']

    def nvcc_flags(self):
        extra_cuda_flags = [
            '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
            '--expt-extended-lambda'
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
        return ret
op_builder/scaled_upper_triangle_masked_softmax.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import os

from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder):
    NAME = "scaled_upper_triangle_masked_softmax"
    PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax"

    def __init__(self):
        super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH)

    def include_dirs(self):
        return [
            self.csrc_abs_path("kernels/include"),
            self.get_cuda_home_include()
        ]

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu']
        ]
        return ret

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = [
            '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr',
            '--expt-extended-lambda'
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
        return append_nvcc_threads(ret)
@@ -1,4 +1,6 @@
import re
import subprocess
from typing import List


def get_cuda_bare_metal_version(cuda_dir):
@@ -11,6 +13,26 @@ def get_cuda_bare_metal_version(cuda_dir):

    return raw_output, bare_metal_major, bare_metal_minor


def get_cuda_cc_flag() -> List:
    """get_cuda_cc_flag

    cc flag for your GPU arch
    """

    # only import torch when needed
    # this is to avoid importing torch when building on a machine without torch pre-installed
    # one case is to build wheel for pypi release
    import torch

    cc_flag = []
    for arch in torch.cuda.get_arch_list():
        res = re.search(r'sm_(\d+)', arch)
        if res:
            arch_cap = res[1]
            if int(arch_cap) >= 60:
                cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])

    return cc_flag


def append_nvcc_threads(nvcc_extra_args):
    from torch.utils.cpp_extension import CUDA_HOME
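For reference, the flags produced by `get_cuda_cc_flag()` look like the following (illustrative output; the actual values depend on `torch.cuda.get_arch_list()` for your PyTorch build):

```python
# Assuming the op_builder package is importable and torch.cuda.get_arch_list()
# contains 'sm_70' and 'sm_80', the function returns one gencode pair per arch
# with compute capability >= 6.0, e.g.:
#   ['-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_80,code=sm_80']
from op_builder.utils import get_cuda_cc_flag

print(get_cuda_cc_flag())
```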