mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[feat] refactored extension module (#5298)
* [feat] refactored extension module * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish
This commit is contained in:
106
extensions/cuda_extension.py
Normal file
106
extensions/cuda_extension.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import List
|
||||
|
||||
from .cpp_extension import _CppExtension
|
||||
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
|
||||
|
||||
__all__ = ["_CudaExtension"]
|
||||
|
||||
# Some constants for installation checks
|
||||
MIN_PYTORCH_VERSION_MAJOR = 1
|
||||
MIN_PYTORCH_VERSION_MINOR = 10
|
||||
|
||||
|
||||
class _CudaExtension(_CppExtension):
|
||||
@abstractmethod
|
||||
def nvcc_flags(self) -> List[str]:
|
||||
"""
|
||||
This function should return a list of nvcc compilation flags for extensions.
|
||||
"""
|
||||
|
||||
def is_hardware_available(self) -> bool:
|
||||
# cuda extension can only be built if cuda is availabe
|
||||
try:
|
||||
import torch
|
||||
|
||||
cuda_available = torch.cuda.is_available()
|
||||
except:
|
||||
cuda_available = False
|
||||
return cuda_available
|
||||
|
||||
def assert_hardware_compatible(self) -> None:
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
if not CUDA_HOME:
|
||||
raise AssertionError(
|
||||
"[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
|
||||
)
|
||||
check_system_pytorch_cuda_match(CUDA_HOME)
|
||||
check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
|
||||
|
||||
def get_cuda_home_include(self):
|
||||
"""
|
||||
return include path inside the cuda home.
|
||||
"""
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return cuda_include
|
||||
|
||||
def build_jit(self) -> None:
|
||||
from torch.utils.cpp_extension import CUDA_HOME, load
|
||||
|
||||
set_cuda_arch_list(CUDA_HOME)
|
||||
|
||||
# get build dir
|
||||
build_directory = _Extension.get_jit_extension_folder_path()
|
||||
build_directory = Path(build_directory)
|
||||
build_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# check if the kernel has been built
|
||||
compiled_before = False
|
||||
kernel_file_path = build_directory.joinpath(f"{self.name}.o")
|
||||
if kernel_file_path.exists():
|
||||
compiled_before = True
|
||||
|
||||
# load the kernel
|
||||
if compiled_before:
|
||||
print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
|
||||
else:
|
||||
print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")
|
||||
|
||||
build_start = time.time()
|
||||
op_kernel = load(
|
||||
name=self.name,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
extra_include_paths=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_cflags=self.cxx_flags(),
|
||||
extra_cuda_cflags=self.nvcc_flags(),
|
||||
extra_ldflags=[],
|
||||
build_directory=str(build_directory),
|
||||
)
|
||||
build_duration = time.time() - build_start
|
||||
|
||||
if compiled_before:
|
||||
print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
|
||||
else:
|
||||
print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")
|
||||
|
||||
return op_kernel
|
||||
|
||||
def build_aot(self) -> "CUDAExtension":
|
||||
from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension
|
||||
|
||||
set_cuda_arch_list(CUDA_HOME)
|
||||
return CUDAExtension(
|
||||
name=self.prebuilt_import_path,
|
||||
sources=self.strip_empty_entries(self.sources_files()),
|
||||
include_dirs=self.strip_empty_entries(self.include_dirs()),
|
||||
extra_compile_args={
|
||||
"cxx": self.strip_empty_entries(self.cxx_flags()),
|
||||
"nvcc": self.strip_empty_entries(self.nvcc_flags()),
|
||||
},
|
||||
)
|
Reference in New Issue
Block a user