Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[setup] support pre-build and jit-build of cuda kernels (#2374)
* [setup] support pre-build and jit-build of cuda kernels
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
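In short, each kernel can now be compiled either ahead of time during pip install (and then imported from colossalai._C) or just-in-time on first use (cached under ~/.cache/colossalai/torch_extensions/). Below is a minimal sketch of how a concrete kernel might subclass the new Builder ABC introduced in this diff; MyKernelBuilder and its source/include paths are hypothetical, for illustration only.

from typing import List

from op_builder.builder import Builder    # import path assumed


class MyKernelBuilder(Builder):
    """Hypothetical builder for a kernel called 'my_kernel' (illustration only)."""

    def __init__(self):
        super().__init__(name='my_kernel',
                         prebuilt_import_path='colossalai._C.my_kernel')

    def sources_files(self) -> List[str]:
        # resolved against the colossalai root, wherever the package lives
        return [self.csrc_abs_path('my_kernel.cpp'),
                self.csrc_abs_path('my_kernel_cuda.cu')]

    def include_dirs(self) -> List[str]:
        return [self.csrc_abs_path('includes'), self.get_cuda_home_include()]

    def cxx_flags(self) -> List[str]:
        return ['-O3'] + self.version_dependent_macros

    def nvcc_flags(self) -> List[str]:
        return ['-O3', '--use_fast_math']


# pre-build path: setup.py puts MyKernelBuilder().builder() into ext_modules,
# so the op ships inside the wheel as colossalai._C.my_kernel.
# jit-build path: MyKernelBuilder().load() tries that import first and only
# compiles into ~/.cache/colossalai/torch_extensions/ when it is missing.

The closing comments show the two entry points this commit wires up: builder() feeds setup.py for the pre-build path, while load() is the runtime JIT path.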
@@ -1,40 +1,49 @@
+import importlib
 import os
 import re
+import time
+from abc import ABC, abstractmethod
 from pathlib import Path
+from typing import List


-def get_cuda_cc_flag() -> List:
-    """get_cuda_cc_flag
-
-    cc flag for your GPU arch
-    """
-
-    # only import torch when needed
-    # this is to avoid importing torch when building on a machine without torch pre-installed
-    # one case is to build wheel for pypi release
-    import torch
-
-    cc_flag = []
-    for arch in torch.cuda.get_arch_list():
-        res = re.search(r'sm_(\d+)', arch)
-        if res:
-            arch_cap = res[1]
-            if int(arch_cap) >= 60:
-                cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
-
-    return cc_flag
-
-
-class Builder(object):
-
-    def colossalai_src_path(self, code_path):
-        current_file_path = Path(__file__)
-        if os.path.islink(current_file_path.parent):
-            # symbolic link
-            return os.path.join(current_file_path.parent.parent.absolute(), code_path)
-        else:
-            return os.path.join(current_file_path.parent.parent.absolute(), "colossalai", "kernel", code_path)
+class Builder(ABC):
+    """
+    Builder is the base class to build extensions for PyTorch.
+
+    Args:
+        name (str): the name of the kernel to be built
+        prebuilt_import_path (str): the path where the extension is installed during pip install
+    """
+
+    def __init__(self, name: str, prebuilt_import_path: str):
+        self.name = name
+        self.prebuilt_import_path = prebuilt_import_path
+        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
+
+        assert prebuilt_import_path.startswith('colossalai._C'), \
+            f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'
+
+    def relative_to_abs_path(self, code_path: str) -> str:
+        """
+        This function takes in a path relative to the colossalai root directory and returns the absolute path.
+        """
+        op_builder_module_path = Path(__file__).parent
+
+        # if we install from source, the current file path will be op_builder/builder.py;
+        # if we install via pip install colossalai, it will be colossalai/kernel/op_builder/builder.py.
+        # this is because the op_builder inside colossalai is a symlink, and the symlink is
+        # replaced with actual files when we install via pypi. thus we cannot find the colossalai
+        # root directory by checking whether op_builder is a symlink; we can only tell whether
+        # it is inside or outside colossalai.
+        if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'):
+            root_path = op_builder_module_path.parent.parent
+        else:
+            root_path = op_builder_module_path.parent.joinpath('colossalai')
+
+        code_abs_path = root_path.joinpath(code_path)
+        return str(code_abs_path)

     def get_cuda_home_include(self):
         """
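The relative_to_abs_path logic added above has to work in two layouts: a source checkout, where op_builder/ sits next to the colossalai/ package, and a pip install, where the symlinked op_builder has been materialized under colossalai/kernel/. A standalone sketch of that decision rule with made-up paths (resolve_root is an illustrative name, not part of the diff):

from pathlib import Path


def resolve_root(op_builder_dir: str) -> Path:
    """Mimic relative_to_abs_path's root detection (illustration only)."""
    p = Path(op_builder_dir)
    if str(p).endswith('colossalai/kernel/op_builder'):
        # installed layout: the colossalai root is two levels up
        return p.parent.parent
    # source checkout: the package sits next to op_builder
    return p.parent.joinpath('colossalai')


print(resolve_root('/site-packages/colossalai/kernel/op_builder'))   # /site-packages/colossalai
print(resolve_root('/repo/op_builder'))                              # /repo/colossalai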
@@ -46,47 +55,94 @@ class Builder(object):
         cuda_include = os.path.join(CUDA_HOME, "include")
         return cuda_include

+    def csrc_abs_path(self, path):
+        return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path)
+
     # functions that must be overridden begin
-    def sources_files(self):
+    @abstractmethod
+    def sources_files(self) -> List[str]:
+        """
+        This function should return a list of source files for extensions.
+        """
         raise NotImplementedError

-    def include_dirs(self):
-        raise NotImplementedError
+    @abstractmethod
+    def include_dirs(self) -> List[str]:
+        """
+        This function should return a list of include directories for extensions.
+        """
+        pass

-    def cxx_flags(self):
-        raise NotImplementedError
+    @abstractmethod
+    def cxx_flags(self) -> List[str]:
+        """
+        This function should return a list of cxx compilation flags for extensions.
+        """
+        pass

-    def nvcc_flags(self):
-        raise NotImplementedError
+    @abstractmethod
+    def nvcc_flags(self) -> List[str]:
+        """
+        This function should return a list of nvcc compilation flags for extensions.
+        """
+        pass

     # functions that must be overridden end

     def strip_empty_entries(self, args):
         '''
         Drop any empty strings from the list of compile and link flags
         '''
         return [x for x in args if len(x) > 0]

+    def import_op(self):
+        """
+        This function will import the op module by its string name.
+        """
+        return importlib.import_module(self.prebuilt_import_path)
+
     def load(self, verbose=True):
         """
-        load and compile cpu_adam lib at runtime
+        load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
+        If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the
+        kernel is built during pip install, it can be accessed through `colossalai._C`.
+
+        Warning: do not load this kernel repeatedly during model execution as it could slow down the training process.

         Args:
             verbose (bool, optional): show detailed info. Defaults to True.
         """
-        import time
-
-        from torch.utils.cpp_extension import load
         start_build = time.time()

-        op_module = load(name=self.name,
-                         sources=self.strip_empty_entries(self.sources_files()),
-                         extra_include_paths=self.strip_empty_entries(self.include_dirs()),
-                         extra_cflags=self.cxx_flags(),
-                         extra_cuda_cflags=self.nvcc_flags(),
-                         extra_ldflags=[],
-                         verbose=verbose)
+        try:
+            op_module = self.import_op()
+            if verbose:
+                print(f"OP {self.prebuilt_import_path} already exists, skip building.")
+        except ImportError:
+            # construct the build directory
+            import torch
+            from torch.utils.cpp_extension import load
+            torch_version_major = torch.__version__.split('.')[0]
+            torch_version_minor = torch.__version__.split('.')[1]
+            torch_cuda_version = torch.version.cuda
+            home_directory = os.path.expanduser('~')
+            extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
+            build_directory = os.path.join(home_directory, extension_directory)
+            Path(build_directory).mkdir(parents=True, exist_ok=True)
+
+            if verbose:
+                print("=========================================================================================")
+                print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
+                print("=========================================================================================")
+
+            # load the kernel
+            op_module = load(name=self.name,
+                             sources=self.strip_empty_entries(self.sources_files()),
+                             extra_include_paths=self.strip_empty_entries(self.include_dirs()),
+                             extra_cflags=self.cxx_flags(),
+                             extra_cuda_cflags=self.nvcc_flags(),
+                             extra_ldflags=[],
+                             build_directory=build_directory,
+                             verbose=verbose)

         build_duration = time.time() - start_build
         if verbose:
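The new load() body above boils down to a generic pattern: import the module produced at pip-install time if it exists, otherwise JIT-compile into a cache directory keyed by the torch and CUDA versions so stale builds are never reused across upgrades. A stripped-down sketch of that pattern; load_or_build and its jit_build callback are illustrative stand-ins, not part of the diff:

import importlib
import os
from pathlib import Path


def load_or_build(prebuilt_import_path: str, jit_build):
    """Prefer the pip-time build; fall back to a cached JIT build (sketch).

    jit_build stands in for torch.utils.cpp_extension.load with the
    kernel's sources and flags already bound.
    """
    try:
        # e.g. colossalai._C.<kernel>, produced by setup.py at pip install
        return importlib.import_module(prebuilt_import_path)
    except ImportError:
        import torch
        major, minor = torch.__version__.split('.')[:2]
        cache_dir = os.path.join(
            os.path.expanduser('~'),
            f'.cache/colossalai/torch_extensions/torch{major}.{minor}_cu{torch.version.cuda}')
        Path(cache_dir).mkdir(parents=True, exist_ok=True)
        return jit_build(build_directory=cache_dir)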
@@ -94,17 +150,16 @@ class Builder(object):

         return op_module

-    def builder(self, name) -> 'CUDAExtension':
+    def builder(self) -> 'CUDAExtension':
         """
         get a CUDAExtension instance used for setup.py
         """
         from torch.utils.cpp_extension import CUDAExtension

-        return CUDAExtension(
-            name=name,
-            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()],
-            include_dirs=self.include_dirs(),
-            extra_compile_args={
-                'cxx': self.cxx_flags(),
-                'nvcc': self.nvcc_flags()
-            })
+        return CUDAExtension(name=self.prebuilt_import_path,
+                             sources=self.strip_empty_entries(self.sources_files()),
+                             include_dirs=self.strip_empty_entries(self.include_dirs()),
+                             extra_compile_args={
+                                 'cxx': self.strip_empty_entries(self.cxx_flags()),
+                                 'nvcc': self.strip_empty_entries(self.nvcc_flags())
+                             })
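Since builder() now names the extension after prebuilt_import_path, setup.py only needs to hand the returned CUDAExtension to setuptools for the pre-build path to work. A hypothetical setup.py fragment, reusing the illustrative MyKernelBuilder from the sketch near the top of this page:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

setup(name='colossalai',
      ext_modules=[MyKernelBuilder().builder()],   # hypothetical builder from above
      cmdclass={'build_ext': BuildExtension})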