diff --git a/op_builder/utils.py b/op_builder/utils.py index cb528eea6..9412c725b 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -197,11 +197,12 @@ def get_cuda_cc_flag() -> List[str]: import torch cc_flag = [] + max_arch = ''.join(str(i) for i in torch.cuda.get_device_capability()) for arch in torch.cuda.get_arch_list(): res = re.search(r'sm_(\d+)', arch) if res: arch_cap = res[1] - if int(arch_cap) >= 60: + if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) return cc_flag