mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[setup] support pre-build and jit-build of cuda kernels (#2374)
* [setup] support pre-build and jit-build of cuda kernels * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import re
|
||||
import subprocess
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
@@ -11,6 +13,26 @@ def get_cuda_bare_metal_version(cuda_dir):
|
||||
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
def get_cuda_cc_flag() -> List:
    """Build the nvcc ``-gencode`` flags for the architectures torch reports.

    Every entry of ``torch.cuda.get_arch_list()`` that contains ``sm_<digits>``
    with a numeric part of at least 60 contributes the pair
    ``['-gencode', 'arch=compute_<digits>,code=<entry>']`` to the result.
    Entries that do not match, or whose number is below 60, are skipped.

    Returns:
        List: a flat list of alternating ``'-gencode'`` and
        ``'arch=...,code=...'`` strings; empty when no entry qualifies.
    """
    # only import torch when needed
    # this is to avoid importing torch when building on a machine without
    # torch pre-installed; one case is to build wheel for pypi release
    import torch

    gencode_flags = []
    for arch_name in torch.cuda.get_arch_list():
        matched = re.search(r'sm_(\d+)', arch_name)
        if matched is None:
            # not an "sm_XX" style entry, nothing to emit for it
            continue

        capability = matched.group(1)
        if int(capability) < 60:
            # architectures below sm_60 are deliberately left out
            continue

        gencode_flags.extend(['-gencode', f'arch=compute_{capability},code={arch_name}'])

    return gencode_flags
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
Reference in New Issue
Block a user