[setup] support pre-build and jit-build of cuda kernels (#2374)

* [setup] support pre-build and jit-build of cuda kernels

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-01-06 20:50:26 +08:00
committed by GitHub
parent 12c8bf38d7
commit 40d376c566
36 changed files with 414 additions and 390 deletions

View File

@@ -1,4 +1,6 @@
import re
import subprocess
from typing import List
def get_cuda_bare_metal_version(cuda_dir):
@@ -11,6 +13,26 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor
def get_cuda_cc_flag() -> List:
"""get_cuda_cc_flag
cc flag for your GPU arch
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
if res:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
return cc_flag
def append_nvcc_threads(nvcc_extra_args):
from torch.utils.cpp_extension import CUDA_HOME