[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)

* refactor compilation mechanism and unified multi hw

* fix file path bug

* add init.py to make pybind a module to avoid relative path error caused by softlink

* delete duplicated micros

* fix micros bug in gcc
This commit is contained in:
傅剑寒
2024-04-24 14:17:54 +08:00
committed by GitHub
parent 04863a9b14
commit 279300dc5f
64 changed files with 345 additions and 310 deletions

View File

@@ -21,6 +21,7 @@ class _CudaExtension(_CppExtension):
"""
This function should return a list of nvcc compilation flags for extensions.
"""
return ["-DCOLOSSAL_WITH_CUDA"]
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
@@ -53,6 +54,12 @@ class _CudaExtension(_CppExtension):
cuda_include = os.path.join(CUDA_HOME, "include")
return cuda_include
def include_dirs(self) -> List[str]:
"""
This function should return a list of include files for extensions.
"""
return super().include_dirs() + [self.get_cuda_home_include()]
def build_jit(self) -> None:
from torch.utils.cpp_extension import CUDA_HOME, load