[kernel] cached the op kernel and fixed version check (#2886)

* [kernel] cached the op kernel and fixed version check

* polish code

* polish code
Frank Lee
2023-03-03 21:45:05 +08:00
committed by GitHub
parent 0ff8406b00
commit 3a5d93bc2c
4 changed files with 325 additions and 137 deletions
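
The heart of the change is an instance-level cache: once a builder has imported a pre-built extension or JIT-compiled it, the module object is kept on the builder so later load() calls return it immediately. A minimal sketch of the pattern, in which every name except load and cached_op_module is hypothetical:

class CachedBuilderSketch:
    """Minimal illustration of the caching behaviour added in this commit."""

    def __init__(self):
        # the built or imported extension module is stored here after the first load()
        self.cached_op_module = None

    def load(self):
        # fast path: a previous call already built or imported the module
        if self.cached_op_module is not None:
            return self.cached_op_module
        op_module = self._build_or_import()    # hypothetical stand-in for import / JIT build
        self.cached_op_module = op_module
        return op_module

    def _build_or_import(self):
        # stand-in for "import the pre-built op or JIT-build it"; returns a dummy object here
        return object()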


@@ -5,22 +5,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import List
def print_rank_0(message):
"""
Print on only one process to avoid spamming.
"""
try:
import torch.distributed as dist
if not dist.is_initialized():
is_main_rank = True
else:
is_main_rank = dist.get_rank() == 0
except ImportError:
is_main_rank = True
if is_main_rank:
print(message)
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
class Builder(ABC):
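
print_rank_0 now lives in .utils rather than being defined here; a tiny usage sketch of the intended behaviour (the script name, launch command and import location are illustrative assumptions):

# demo.py -- launch with e.g. `torchrun --nproc_per_node=2 demo.py`
import torch.distributed as dist

from op_builder.utils import print_rank_0   # assumed import location for the relocated helper

dist.init_process_group(backend="gloo")
print_rank_0("[extension] printed once, by rank 0 only")   # the other rank stays silent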
@@ -37,6 +22,9 @@ class Builder(ABC):
self.prebuilt_import_path = prebuilt_import_path
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
# we store the op as an attribute to avoid repeated building and loading
self.cached_op_module = None
assert prebuilt_import_path.startswith('colossalai._C'), \
f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'
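
The assertion above constrains where pre-built kernels may live; purely as an illustration, a hypothetical concrete builder could satisfy it like this (the kernel name and the assumed __init__ signature are not taken from this diff):

# hypothetical subclass -- assumes Builder.__init__(self, name, prebuilt_import_path)
class MyKernelBuilder(Builder):

    def __init__(self):
        super().__init__(
            name="my_kernel",                                 # illustrative op name
            prebuilt_import_path="colossalai._C.my_kernel",   # must start with 'colossalai._C'
        )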
@@ -117,6 +105,35 @@ class Builder(ABC):
"""
return importlib.import_module(self.prebuilt_import_path)
def check_runtime_build_environment(self):
"""
Check whether the system environment is ready for extension compilation.
"""
try:
import torch
from torch.utils.cpp_extension import CUDA_HOME
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
CUDA_HOME = None
if not TORCH_AVAILABLE:
raise ModuleNotFoundError(
"PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions")
if CUDA_HOME is None:
raise RuntimeError(
"CUDA_HOME is not found. You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions"
)
# make sure CUDA is available for compilation during runtime
cuda_available = check_cuda_availability()
if not cuda_available:
raise RuntimeError("CUDA is not available on your system as torch.cuda.is_avaible() returns False.")
# make sure the system CUDA and PyTorch CUDA versions match; an error will be raised inside the function if they do not
check_system_pytorch_cuda_match(CUDA_HOME)
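
check_system_pytorch_cuda_match comes from .utils and its body is not part of this diff; the sketch below only illustrates the kind of comparison such a check typically performs (the helper name, parsing and failure policy are all assumptions):

# illustrative sketch only -- the real check lives in .utils and may differ
import re
import subprocess

def sketch_system_pytorch_cuda_match(cuda_home: str) -> None:
    import torch
    torch_cuda = torch.version.cuda                     # CUDA version PyTorch was built with, e.g. "11.7"
    nvcc_output = subprocess.check_output([f"{cuda_home}/bin/nvcc", "--version"], text=True)
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    system_cuda = match.group(1) if match else None     # CUDA version of the installed toolkit
    if system_cuda is None or system_cuda.split(".")[0] != torch_cuda.split(".")[0]:
        raise RuntimeError(f"System CUDA {system_cuda} does not match PyTorch CUDA {torch_cuda}")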
def load(self, verbose=True):
"""
Load the kernel at runtime. If the kernel was not built during pip install, it will be built now.
@@ -128,16 +145,27 @@ class Builder(ABC):
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
from torch.utils.cpp_extension import load
start_build = time.time()
# if the kernel has been compiled and cached, we use it directly
if self.cached_op_module is not None:
return self.cached_op_module
try:
# if the kernel has been pre-built during installation
# we just directly import it
op_module = self.import_op()
if verbose:
print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
print_rank_0(
f"[extension] OP {self.prebuilt_import_path} has been compileed ahead of time, skip building.")
except ImportError:
# check environment
self.check_runtime_build_environment()
# time the kernel compilation
start_build = time.time()
# construct the build directory
import torch
from torch.utils.cpp_extension import load
torch_version_major = torch.__version__.split('.')[0]
torch_version_minor = torch.__version__.split('.')[1]
torch_cuda_version = torch.version.cuda
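
The statement that actually assembles build_directory falls between these two hunks and is not shown; the snippet below is only a hypothetical illustration of why the version strings above are parsed, namely to give each torch/CUDA combination its own cache directory (the cache root and naming scheme are assumptions):

# hypothetical illustration of a version-tagged extension cache directory
from pathlib import Path

extension_dir_name = f"torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}"
build_directory = Path.home() / ".cache" / "colossalai" / "torch_extensions" / extension_dir_name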
@@ -147,11 +175,7 @@ class Builder(ABC):
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
print_rank_0(
"=========================================================================================")
print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
print_rank_0(
"=========================================================================================")
print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")
# load the kernel
op_module = load(name=self.name,
@@ -163,9 +187,14 @@ class Builder(ABC):
build_directory=build_directory,
verbose=verbose)
build_duration = time.time() - start_build
if verbose:
print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
build_duration = time.time() - start_build
# log jit compilation time
if verbose:
print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")
# cache the built/loaded kernel
self.cached_op_module = op_module
return op_module
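
Taken together, a second load() on the same builder instance is now effectively free; a short usage sketch (MyKernelBuilder is the hypothetical subclass from the earlier sketch):

# hypothetical usage: the second call is served from self.cached_op_module
builder = MyKernelBuilder()
op_module = builder.load(verbose=True)   # imports the pre-built op, or JIT-builds it
op_module_again = builder.load()         # no rebuild, returns the cached module
assert op_module is op_module_again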