[kernel] cached the op kernel and fixed version check (#2886)

* [kernel] cached the op kernel and fixed version check

* polish code

* polish code
Frank Lee 2023-03-03 21:45:05 +08:00 committed by GitHub
parent 0ff8406b00
commit 3a5d93bc2c
4 changed files with 325 additions and 137 deletions


@@ -15,17 +15,18 @@ Method 2 is good because it allows the user to only build the kernel they actual
## PyTorch Extensions in Colossal-AI
As mentioned in the section above, our aim is to make these two methods coherently supported in Colossal-AI, meaning that a kernel can be built either in `setup.py` or at runtime. The DeepSpeed project (https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) to support kernel builds during either installation or runtime.
We have adapted DeepSpeed's solution to build our extensions. The extension build requires two main functions from PyTorch:
1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`.
2. `torch.utils.cpp_extension.load`: used to build and load extensions during runtime.
Please note that an extension built by `CUDAExtension` cannot be loaded by the `load` function, and `load` will run its own build again (correct me if I am wrong).
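As a quick illustration, a runtime build with `load` might look like the sketch below; the kernel name and source paths are hypothetical placeholders, not real files in this repo:

```python
from torch.utils.cpp_extension import load

# JIT-compile and load a CUDA extension at runtime.
# "my_fused_kernel" and the source paths are illustrative only.
my_kernel = load(name="my_fused_kernel",
                 sources=["kernels/my_fused_kernel.cpp", "kernels/my_fused_kernel_cuda.cu"],
                 extra_cuda_cflags=["-O3"],
                 verbose=True)
```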
Based on DeepSpeed's work, we have made several modifications and improvements:
1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C`
2. All runtime-built kernels will be found in the default torch extension path, i.e. `~/.cache/colossalai/torch_extensions`. (If we put the built kernels in the installed site-package directory, pip uninstall would be incomplete.)
3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading.
When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered.
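Putting these conventions together, the loading order can be sketched as follows. The attribute names match the builder code below; `build_jit` is a hypothetical stand-in for the runtime-build step:

```python
import importlib

def load_kernel(builder):
    # 1. reuse the module cached on the builder, if any
    if builder.cached_op_module is not None:
        return builder.cached_op_module
    try:
        # 2. prefer the kernel pre-built into colossalai._C during pip install
        op_module = importlib.import_module(builder.prebuilt_import_path)
    except ImportError:
        # 3. fall back to a JIT build under ~/.cache/colossalai/torch_extensions
        op_module = builder.build_jit()  # hypothetical helper
    builder.cached_op_module = op_module
    return op_module
```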


@@ -5,22 +5,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import List

from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0


class Builder(ABC):
@@ -37,6 +22,9 @@ class Builder(ABC):
        self.prebuilt_import_path = prebuilt_import_path
        self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

        # we store the op as an attribute to avoid repeated building and loading
        self.cached_op_module = None

        assert prebuilt_import_path.startswith('colossalai._C'), \
            f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'
@@ -117,6 +105,35 @@
        """
        return importlib.import_module(self.prebuilt_import_path)

    def check_runtime_build_environment(self):
        """
        Check whether the system environment is ready for extension compilation.
        """
        try:
            import torch
            from torch.utils.cpp_extension import CUDA_HOME
            TORCH_AVAILABLE = True
        except ImportError:
            TORCH_AVAILABLE = False
            CUDA_HOME = None

        if not TORCH_AVAILABLE:
            raise ModuleNotFoundError(
                "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions")

        if CUDA_HOME is None:
            raise RuntimeError(
                "CUDA_HOME is not found. You need to export the CUDA_HOME environment variable or install the CUDA Toolkit first in order to build CUDA extensions"
            )

        # make sure CUDA is available for compilation during the runtime build
        cuda_available = check_cuda_availability()
        if not cuda_available:
            raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.")

        # make sure the system CUDA and PyTorch CUDA versions match; an error will be raised inside the function if not
        check_system_pytorch_cuda_match(CUDA_HOME)
    def load(self, verbose=True):
        """
        load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
@@ -128,16 +145,27 @@
        Args:
            verbose (bool, optional): show detailed info. Defaults to True.
        """
        # if the kernel has been compiled and cached, we directly use it
        if self.cached_op_module is not None:
            return self.cached_op_module

        try:
            # if the kernel has been pre-built during installation
            # we just directly import it
            op_module = self.import_op()
            if verbose:
                print_rank_0(
                    f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building.")
        except ImportError:
            # check environment
            self.check_runtime_build_environment()

            # time the kernel compilation
            start_build = time.time()

            # construct the build directory
            import torch
            from torch.utils.cpp_extension import load
            torch_version_major = torch.__version__.split('.')[0]
            torch_version_minor = torch.__version__.split('.')[1]
            torch_cuda_version = torch.version.cuda
@@ -147,11 +175,7 @@
            Path(build_directory).mkdir(parents=True, exist_ok=True)

            if verbose:
                print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")

            # load the kernel
            op_module = load(name=self.name,
@@ -163,9 +187,14 @@
                             build_directory=build_directory,
                             verbose=verbose)

            build_duration = time.time() - start_build

            # log jit compilation time
            if verbose:
                print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")

        # cache the built/loaded kernel
        self.cached_op_module = op_module

        return op_module
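For reference, this is a minimal usage sketch of the cached loading behavior, assuming `SomeOpBuilder` is a concrete subclass of `Builder` (the name is hypothetical):

```python
# a minimal sketch, assuming SomeOpBuilder is a concrete Builder subclass
builder = SomeOpBuilder()

op = builder.load()        # first call: imports the pre-built kernel or JIT-builds it
op_again = builder.load()  # second call: returns the cached module immediately

assert op is op_again      # the same extension module object is reused
```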


@@ -1,29 +1,203 @@
import os
import re
import subprocess
import warnings
from typing import List, Tuple


def print_rank_0(message: str) -> None:
    """
    Print on only one process to avoid spamming.
    """
    try:
        import torch.distributed as dist
        if not dist.is_initialized():
            is_main_rank = True
        else:
            is_main_rank = dist.get_rank() == 0
    except ImportError:
        is_main_rank = True

    if is_main_rank:
        print(message)


def get_cuda_version_in_pytorch() -> Tuple[str, str]:
    """
    This function returns the CUDA version in the PyTorch build.

    Returns:
        The CUDA version required by PyTorch, in the form of a tuple (major, minor).
    """
    import torch
    try:
        torch_cuda_major = torch.version.cuda.split(".")[0]
        torch_cuda_minor = torch.version.cuda.split(".")[1]
    except AttributeError:
        raise ValueError(
            "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda")
    return torch_cuda_major, torch_cuda_minor
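As a concrete illustration (the version value is made up), a CUDA build of PyTorch reports a string such as `'11.3'` in `torch.version.cuda`, while a CPU-only build reports `None`, which is the failure mode the exception above guards against:

```python
# illustrative: splitting a CUDA version string the same way as above
torch_cuda_version = "11.3"    # stand-in for torch.version.cuda
major, minor = torch_cuda_version.split(".")[0], torch_cuda_version.split(".")[1]
print(major, minor)            # prints: 11 3
```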


def get_cuda_bare_metal_version(cuda_dir) -> Tuple[str, str]:
    """
    Get the system CUDA version from nvcc.

    Args:
        cuda_dir (str): the directory for the CUDA Toolkit.

    Returns:
        The system CUDA version, in the form of a tuple (major, minor).
    """
    if cuda_dir is None:
        raise ValueError(
            "[extension] The argument cuda_dir is None, but expected to be a string. Please make sure you have exported the environment variable CUDA_HOME correctly."
        )

    # check for the nvcc path
    nvcc_path = os.path.join(cuda_dir, 'bin/nvcc')
    if not os.path.exists(nvcc_path):
        raise FileNotFoundError(
            f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
        )

    # parse the nvcc -V output to obtain the system cuda version
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
    try:
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]
    except (ValueError, IndexError):
        raise ValueError(
            f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -V' is \n{raw_output}"
        )

    return bare_metal_major, bare_metal_minor
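For intuition, the last line of a typical `nvcc -V` output reads like `Cuda compilation tools, release 11.3, V11.3.109` (version numbers here are illustrative), and the parsing above extracts the token following `release`:

```python
# illustrative: parsing the tail of a typical `nvcc -V` output
raw_output = "Cuda compilation tools, release 11.3, V11.3.109"
output = raw_output.split()
release = output[output.index("release") + 1].split(".")   # ['11', '3,']
major, minor = release[0], release[1][0]                   # ('11', '3')
print(major, minor)                                        # prints: 11 3
```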


def check_system_pytorch_cuda_match(cuda_dir):
    bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()

    if bare_metal_major != torch_cuda_major:
        raise Exception(
            f'[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) '
            f'mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}). '
            'Please make sure you have set CUDA_HOME correctly and installed the matching PyTorch from https://pytorch.org/get-started/locally/ .'
        )

    if bare_metal_minor != torch_cuda_minor:
        warnings.warn(
            f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
            "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
            "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions."
        )
    return True


def get_pytorch_version() -> Tuple[int, int, int]:
    """
    This function finds the PyTorch version.

    Returns:
        A tuple of integers in the form of (major, minor, patch).
    """
    import torch
    torch_version = torch.__version__.split('+')[0]
    TORCH_MAJOR = int(torch_version.split('.')[0])
    TORCH_MINOR = int(torch_version.split('.')[1])
    TORCH_PATCH = int(torch_version.split('.')[2])
    return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH


def check_pytorch_version(min_major_version, min_minor_version) -> bool:
    """
    Compare the current PyTorch version with the minimum required version.

    Args:
        min_major_version (int): the minimum major version of PyTorch required
        min_minor_version (int): the minimum minor version of PyTorch required

    Returns:
        A boolean value. The value is True if the current PyTorch version is acceptable; a RuntimeError is raised otherwise.
    """
    # get pytorch version
    torch_major, torch_minor, _ = get_pytorch_version()

    # if the installed version is older than the minimum required version, refuse to proceed
    if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
        raise RuntimeError(
            f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
            "The latest stable release can be obtained from https://pytorch.org/get-started/locally/")
    return True


def check_cuda_availability():
    """
    Check if CUDA is available on the system.

    Returns:
        A boolean value. True if CUDA is available and False otherwise.
    """
    import torch
    return torch.cuda.is_available()


def set_cuda_arch_list(cuda_dir):
    """
    This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
    Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'.
    """
    cuda_available = check_cuda_availability()

    # we only need to set this when CUDA is not available for cross-compilation
    if not cuda_available:
        warnings.warn(
            '\n[extension] PyTorch did not find available GPUs on this system.\n'
            'If your intention is to cross-compile, this is not an error.\n'
            'By default, Colossal-AI will cross-compile for\n'
            '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
            '2. Volta (compute capability 7.0),\n'
            '3. Turing (compute capability 7.5),\n'
            '4. Ampere (compute capabilities 8.0, 8.6) if the CUDA version is >= 11.0\n'
            '\nIf you wish to cross-compile for a single specific architecture,\n'
            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')

        if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
            bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
            arch_list = ['6.0', '6.1', '6.2', '7.0', '7.5']

            if int(bare_metal_major) == 11:
                if int(bare_metal_minor) == 0:
                    arch_list.append('8.0')
                else:
                    arch_list.append('8.0')
                    arch_list.append('8.6')

            arch_list_str = ';'.join(arch_list)
            os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
        return False
    return True


def get_cuda_cc_flag() -> List[str]:
    """
    This function produces the cc flags for your GPU arch.

    Returns:
        The CUDA cc flags for compilation.
    """
    # only import torch when needed
    # this is to avoid importing torch when building on a machine without torch pre-installed
    # one case is to build wheel for pypi release
    import torch

    cc_flag = []
    for arch in torch.cuda.get_arch_list():
        res = re.search(r'sm_(\d+)', arch)
@@ -31,12 +205,19 @@ def get_cuda_cc_flag() -> List:
        arch_cap = res[1]
        if int(arch_cap) >= 60:
            cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
    return cc_flag
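For intuition, the loop above turns each supported `sm_XX` architecture into a `-gencode` pair. With an arch list of `['sm_70', 'sm_80']` (an illustrative stand-in for `torch.cuda.get_arch_list()`), it would produce:

```python
import re

# illustrative: reproduce the flag construction for a fixed arch list
cc_flag = []
for arch in ['sm_70', 'sm_80']:    # stand-in for torch.cuda.get_arch_list()
    cap = re.search(r'sm_(\d+)', arch)[1]
    if int(cap) >= 60:
        cc_flag.extend(['-gencode', f'arch=compute_{cap},code={arch}'])
print(cc_flag)
# ['-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_80,code=sm_80']
```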


def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
    """
    This function appends the threads flag to your nvcc args.

    Returns:
        The nvcc compilation flags including the threads flag.
    """
    from torch.utils.cpp_extension import CUDA_HOME
    bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args

setup.py

@@ -1,115 +1,87 @@
import os
from datetime import datetime
from typing import List

from setuptools import find_packages, setup

from op_builder.utils import (
    check_cuda_availability,
    check_pytorch_version,
    check_system_pytorch_cuda_match,
    get_cuda_bare_metal_version,
    get_pytorch_version,
    set_cuda_arch_list,
)

try:
    import torch
    from torch.utils.cpp_extension import CUDA_HOME, BuildExtension
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    CUDA_HOME = None

# Some constants for installation checks
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1
IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1

# a list to store the built extension modules
ext_modules = []


# check for CUDA extension dependencies
def environment_check_for_cuda_extension_build():
    if not TORCH_AVAILABLE:
        raise ModuleNotFoundError(
            "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions"
        )

    if not CUDA_HOME:
        raise RuntimeError(
            "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export the CUDA_HOME environment variable or install the CUDA Toolkit first in order to build CUDA extensions"
        )

    check_system_pytorch_cuda_match(CUDA_HOME)
    check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
    check_cuda_availability()


def fetch_requirements(path) -> List[str]:
    """
    This function reads the requirements file.

    Args:
        path (str): the path to the requirements file.

    Returns:
        The lines in the requirements file.
    """
    with open(path, 'r') as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme() -> str:
    """
    This function reads the README.md file in the current directory.

    Returns:
        The content of the README file.
    """
    with open('README.md', encoding='utf-8') as f:
        return f.read()


def get_version() -> str:
    """
    This function reads the version.txt and generates the colossalai/version.py file.

    Returns:
        The library version stored in version.txt.
    """
    setup_file_path = os.path.abspath(__file__)
    project_path = os.path.dirname(setup_file_path)
    version_txt_path = os.path.join(project_path, 'version.txt')
@@ -121,13 +93,17 @@ def get_version():
    # write version into version.py
    with open(version_py_path, 'w') as f:
        f.write(f"__version__ = '{version}'\n")

        # look for pytorch and cuda version
        if BUILD_CUDA_EXT:
            torch_major, torch_minor, _ = get_pytorch_version()
            torch_version = f'{torch_major}.{torch_minor}'
            cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME))
        else:
            torch_version = None
            cuda_version = None

        # write the version into the python file
        if torch_version:
            f.write(f'torch = "{torch_version}"\n')
        else:
@@ -141,25 +117,26 @@ def get_version():
    return version


if BUILD_CUDA_EXT:
    environment_check_for_cuda_extension_build()
    set_cuda_arch_list(CUDA_HOME)

    from op_builder import ALL_OPS
    op_names = []

    # load all builders
    for name, builder_cls in ALL_OPS.items():
        op_names.append(name)
        ext_modules.append(builder_cls().builder())

    # show log
    op_name_list = ', '.join(op_names)
    print(f"[extension] loaded builders for {op_name_list}")

# always put the non-nightly branch as the if branch
# otherwise github will treat colossalai-nightly as the project name
# and it will mess up with the dependency graph insights
if not IS_NIGHTLY:
    version = get_version()
    package_name = 'colossalai'
else: