diff --git a/setup.py b/setup.py index b9cd9e5e4..7cfbbe9b1 100644 --- a/setup.py +++ b/setup.py @@ -117,14 +117,26 @@ def get_version(): with open(version_txt_path) as f: version = f.read().strip() - if build_cuda_ext: - torch_version = '.'.join(torch.__version__.split('.')[:2]) - cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)[1:]) - version += f'+torch{torch_version}cu{cuda_version}' # write version into version.py with open(version_py_path, 'w') as f: f.write(f"__version__ = '{version}'\n") + if build_cuda_ext: + torch_version = '.'.join(torch.__version__.split('.')[:2]) + cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)[1:]) + else: + torch_version = None + cuda_version = None + + if torch_version: + f.write(f'torch = "{torch_version}"\n') + else: + f.write('torch = None\n') + + if cuda_version: + f.write(f'cuda = "{cuda_version}"\n') + else: + f.write('cuda = None\n') return version