mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[builder] raise Error when CUDA_HOME is not set (#2213)
This commit is contained in:
@@ -30,6 +30,13 @@ class Builder(object):
|
||||
else:
|
||||
return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
|
||||
|
||||
def get_cuda_include(self):
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return cuda_include
|
||||
|
||||
def strip_empty_entries(self, args):
|
||||
'''
|
||||
Drop any empty strings from the list of compile and link flags
|
||||
|
@@ -27,9 +27,7 @@ class CPUAdamBuilder(Builder):
|
||||
]
|
||||
|
||||
def include_paths(self):
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), cuda_include]
|
||||
return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), self.get_cuda_include()]
|
||||
|
||||
def strip_empty_entries(self, args):
|
||||
'''
|
||||
|
@@ -31,10 +31,7 @@ class FusedOptimBuilder(Builder):
|
||||
]
|
||||
|
||||
def include_paths(self):
|
||||
import torch
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), cuda_include]
|
||||
return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_include()]
|
||||
|
||||
def builder(self, name):
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
|
@@ -31,10 +31,8 @@ class MultiHeadAttnBuilder(Builder):
|
||||
]
|
||||
|
||||
def include_paths(self):
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
ret = []
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
ret = [os.path.join(self.base_dir, "includes"), cuda_include]
|
||||
ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_include()]
|
||||
ret.append(os.path.join(self.base_dir, "kernels", "include"))
|
||||
print("include_paths", ret)
|
||||
return ret
|
||||
|
Reference in New Issue
Block a user