[setup] support pre-build and jit-build of cuda kernels (#2374)

* [setup] support pre-build and jit-build of cuda kernels

* polish code

Author: Frank Lee
Date: 2023-01-06 20:50:26 +08:00
Committed by: GitHub
Parent: 12c8bf38d7
Commit: 40d376c566
36 changed files with 414 additions and 390 deletions
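In effect, every kernel gains two delivery paths: it can be pre-built into the `colossalai._C` package when the wheel is compiled, or JIT-built into a user-level cache the first time it is requested at runtime. A minimal sketch of how a concrete op builder would use the `Builder` base class introduced below (the `MyKernelBuilder` class and its source file are hypothetical, for illustration only):

# hypothetical subclass of the Builder ABC introduced in this commit
from op_builder.builder import Builder

class MyKernelBuilder(Builder):

    def __init__(self):
        super().__init__(name='my_kernel',
                         prebuilt_import_path='colossalai._C.my_kernel')

    def sources_files(self):
        return [self.csrc_abs_path('my_kernel.cu')]    # assumed source file

    def include_dirs(self):
        return [self.get_cuda_home_include()]

    def cxx_flags(self):
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self):
        return ['-O3']

# runtime: import the pre-built colossalai._C.my_kernel if it exists,
# otherwise JIT-build it into ~/.cache/colossalai/torch_extensions/
my_kernel = MyKernelBuilder().load()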

op_builder/builder.py

@@ -1,40 +1,49 @@
+import importlib
 import os
-import re
+import time
+from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import List
 
 
-def get_cuda_cc_flag() -> List:
-    """get_cuda_cc_flag
-
-    cc flag for your GPU arch
-    """
-
-    # only import torch when needed
-    # this is to avoid importing torch when building on a machine without torch pre-installed
-    # one case is to build wheel for pypi release
-    import torch
-
-    cc_flag = []
-    for arch in torch.cuda.get_arch_list():
-        res = re.search(r'sm_(\d+)', arch)
-        if res:
-            arch_cap = res[1]
-            if int(arch_cap) >= 60:
-                cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
-
-    return cc_flag
-
-
-class Builder(object):
-
-    def colossalai_src_path(self, code_path):
-        current_file_path = Path(__file__)
-        if os.path.islink(current_file_path.parent):
-            # symbolic link
-            return os.path.join(current_file_path.parent.parent.absolute(), code_path)
-        else:
-            return os.path.join(current_file_path.parent.parent.absolute(), "colossalai", "kernel", code_path)
+class Builder(ABC):
+    """
+    Builder is the base class to build extensions for PyTorch.
+
+    Args:
+        name (str): the name of the kernel to be built
+        prebuilt_import_path (str): the path where the extension is installed during pip install
+    """
+
+    def __init__(self, name: str, prebuilt_import_path: str):
+        self.name = name
+        self.prebuilt_import_path = prebuilt_import_path
+        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
+
+        assert prebuilt_import_path.startswith('colossalai._C'), \
+            f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'
+
+    def relative_to_abs_path(self, code_path: str) -> str:
+        """
+        This function takes in a path relative to the colossalai root directory and returns the absolute path.
+        """
+        op_builder_module_path = Path(__file__).parent
+
+        # if we install from source
+        # the current file path will be op_builder/builder.py
+        # if we install via pip install colossalai
+        # the current file path will be colossalai/kernel/op_builder/builder.py
+        # this is because the op_builder inside colossalai is a symlink
+        # this symlink will be replaced with actual files if we install via pypi
+        # thus we cannot tell the colossalai root directory by checking whether the op_builder
+        # is a symlink, we can only tell whether it is inside or outside colossalai
+        if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'):
+            root_path = op_builder_module_path.parent.parent
+        else:
+            root_path = op_builder_module_path.parent.joinpath('colossalai')
+
+        code_abs_path = root_path.joinpath(code_path)
+        return str(code_abs_path)
 
     def get_cuda_home_include(self):
         """
@@ -46,47 +55,94 @@ class Builder(object):
         cuda_include = os.path.join(CUDA_HOME, "include")
         return cuda_include
 
+    def csrc_abs_path(self, path):
+        return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path)
+
     # functions must be overridden begin
-    def sources_files(self):
-        raise NotImplementedError
+    @abstractmethod
+    def sources_files(self) -> List[str]:
+        """
+        This function should return a list of source files for extensions.
+        """
+        pass
 
-    def include_dirs(self):
-        raise NotImplementedError
+    @abstractmethod
+    def include_dirs(self) -> List[str]:
+        """
+        This function should return a list of include directories for extensions.
+        """
+        pass
 
-    def cxx_flags(self):
-        raise NotImplementedError
+    @abstractmethod
+    def cxx_flags(self) -> List[str]:
+        """
+        This function should return a list of cxx compilation flags for extensions.
+        """
+        pass
 
-    def nvcc_flags(self):
-        raise NotImplementedError
+    @abstractmethod
+    def nvcc_flags(self) -> List[str]:
+        """
+        This function should return a list of nvcc compilation flags for extensions.
+        """
+        pass
 
     # functions must be overridden end
 
     def strip_empty_entries(self, args):
         '''
         Drop any empty strings from the list of compile and link flags
         '''
         return [x for x in args if len(x) > 0]
 
+    def import_op(self):
+        """
+        This function will import the op module by its string name.
+        """
+        return importlib.import_module(self.prebuilt_import_path)
+
     def load(self, verbose=True):
         """
-        load and compile cpu_adam lib at runtime
+        load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
+        If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
+        kernel is built during pip install, it can be accessed through `colossalai._C`.
+
+        Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.
 
         Args:
             verbose (bool, optional): show detailed info. Defaults to True.
         """
-        import time
-
         from torch.utils.cpp_extension import load
         start_build = time.time()
 
-        op_module = load(name=self.name,
-                         sources=self.strip_empty_entries(self.sources_files()),
-                         extra_include_paths=self.strip_empty_entries(self.include_dirs()),
-                         extra_cflags=self.cxx_flags(),
-                         extra_cuda_cflags=self.nvcc_flags(),
-                         extra_ldflags=[],
-                         verbose=verbose)
+        try:
+            op_module = self.import_op()
+            if verbose:
+                print(f"OP {self.prebuilt_import_path} already exists, skip building.")
+        except ImportError:
+            # construct the build directory
+            import torch
+            torch_version_major = torch.__version__.split('.')[0]
+            torch_version_minor = torch.__version__.split('.')[1]
+            torch_cuda_version = torch.version.cuda
+            home_directory = os.path.expanduser('~')
+            extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
+            build_directory = os.path.join(home_directory, extension_directory)
+            Path(build_directory).mkdir(parents=True, exist_ok=True)
+
+            if verbose:
+                print("=========================================================================================")
+                print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
+                print("=========================================================================================")
+
+            # load the kernel
+            op_module = load(name=self.name,
+                             sources=self.strip_empty_entries(self.sources_files()),
+                             extra_include_paths=self.strip_empty_entries(self.include_dirs()),
+                             extra_cflags=self.cxx_flags(),
+                             extra_cuda_cflags=self.nvcc_flags(),
+                             extra_ldflags=[],
+                             build_directory=build_directory,
+                             verbose=verbose)
 
         build_duration = time.time() - start_build
         if verbose:
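The build directory key encodes the torch and CUDA versions, so a kernel cached under one environment is never reused by another. Reproducing the computation standalone (version numbers assumed for illustration):

import os
import torch

major, minor = torch.__version__.split('.')[:2]
build_directory = os.path.join(
    os.path.expanduser('~'),
    f".cache/colossalai/torch_extensions/torch{major}.{minor}_cu{torch.version.cuda}")
print(build_directory)
# e.g. torch 1.13.1 with CUDA 11.6 ->
# /home/user/.cache/colossalai/torch_extensions/torch1.13_cu11.6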
@@ -94,17 +150,16 @@ class Builder(object):
 
         return op_module
 
-    def builder(self, name) -> 'CUDAExtension':
+    def builder(self) -> 'CUDAExtension':
         """
         get a CUDAExtension instance used for setup.py
         """
         from torch.utils.cpp_extension import CUDAExtension
 
-        return CUDAExtension(
-            name=name,
-            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()],
-            include_dirs=self.include_dirs(),
-            extra_compile_args={
-                'cxx': self.cxx_flags(),
-                'nvcc': self.nvcc_flags()
-            })
+        return CUDAExtension(name=self.prebuilt_import_path,
+                             sources=self.strip_empty_entries(self.sources_files()),
+                             include_dirs=self.strip_empty_entries(self.include_dirs()),
+                             extra_compile_args={
+                                 'cxx': self.strip_empty_entries(self.cxx_flags()),
+                                 'nvcc': self.strip_empty_entries(self.nvcc_flags())
+                             })
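On the pre-build side, the `CUDAExtension` returned by `builder()` slots into a standard setuptools/torch extension build. A sketch of the setup.py wiring, reusing the hypothetical `MyKernelBuilder` from above (the `CUDA_EXT` opt-in gate is an illustrative assumption, not necessarily the project's actual setup.py):

import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

# assumption: pre-building is opt-in via an environment variable
ext_modules = []
if os.environ.get('CUDA_EXT', '0') == '1':
    ext_modules.append(MyKernelBuilder().builder())   # hypothetical builder from above

setup(name='colossalai',
      ext_modules=ext_modules,
      cmdclass={'build_ext': BuildExtension})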