[builder] multihead attn runtime building (#2203)

* [hotfix] correct cpu_optim runtime compilation

* [builder] multihead attn

* fix bug

* fix a bug
This commit is contained in:
Jiarui Fang
2022-12-27 16:06:09 +08:00
committed by GitHub
parent 8e22c38b89
commit 1cb532ffec
7 changed files with 88 additions and 25 deletions

View File

@@ -1,7 +1,26 @@
import os
import re
import sys
from pathlib import Path
import torch
def get_cuda_cc_flag():
    """Build nvcc ``-gencode`` flags from the CUDA arch list reported by PyTorch.

    Each entry of ``torch.cuda.get_arch_list()`` that contains ``sm_<digits>``
    with a numeric part of at least 60 contributes the pair
    ``['-gencode', 'arch=compute_<digits>,code=<entry>']``. Entries that do
    not contain ``sm_<digits>`` or whose numeric part is below 60 are skipped.

    Returns:
        list: flat list of flag strings (two items per accepted entry);
        empty when no entry qualifies.
    """
    cc_flag = []
    for arch in torch.cuda.get_arch_list():
        res = re.search(r'sm_(\d+)', arch)
        if res is None:
            # Entry carries no 'sm_<digits>' token; nothing to emit for it.
            continue
        arch_cap = res.group(1)
        # NOTE(review): only the digits feed 'compute_'; a suffixed entry such
        # as 'sm_90a' would give arch=compute_90 with code=sm_90a -- confirm
        # that nvcc accepts that pairing.
        if int(arch_cap) >= 60:
            cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
    return cc_flag
class Builder(object):