Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-06 11:31:12 +00:00)
feat: Support windows
fix: Fix install error on linux
doc: Add torch cuda FAQ
Changed files:
setup.py (210 lines changed)
```diff
@@ -5,12 +5,19 @@ import platform
 import subprocess
 import os
 from enum import Enum

+import urllib.request
+from urllib.parse import urlparse, quote
+import re
+from pip._internal.utils.appdirs import user_cache_dir
+import shutil
+import tempfile
 from setuptools import find_packages

-with open("README.md", "r") as fh:
+with open("README.md", mode="r", encoding="utf-8") as fh:
     long_description = fh.read()

+BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "false").lower() == "true"
+

 def parse_requirements(file_name: str) -> List[str]:
     with open(file_name) as f:
```
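Two things worth noting in this hunk: the README is now opened with an explicit `encoding="utf-8"`, which avoids a `UnicodeDecodeError` on hosts whose default locale encoding is not UTF-8 (Windows code pages such as cp936, or a POSIX/C locale on Linux), and a new `BUILD_NO_CACHE` environment switch is introduced for the wheel cache added further down. A minimal sketch of how the flag reads, with the value set inline purely for illustration:

```python
import os

# Illustrative only; in practice you would export BUILD_NO_CACHE=true
# in the shell before running `pip install .`.
os.environ["BUILD_NO_CACHE"] = "true"

# Same expression as in setup.py: any casing of "true" enables the bypass,
# everything else (including unset) leaves wheel caching on.
BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "false").lower() == "true"
print(BUILD_NO_CACHE)  # True
```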
```diff
@@ -21,9 +28,70 @@ def parse_requirements(file_name: str) -> List[str]:
     ]


+def get_latest_version(package_name: str, index_url: str, default_version: str):
+    command = [
+        "python",
+        "-m",
+        "pip",
+        "index",
+        "versions",
+        package_name,
+        "--index-url",
+        index_url,
+    ]
+
+    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    if result.returncode != 0:
+        print("Error executing command.")
+        print(result.stderr.decode())
+        return default_version
+
+    output = result.stdout.decode()
+    lines = output.split("\n")
+    for line in lines:
+        if "Available versions:" in line:
+            available_versions = line.split(":")[1].strip()
+            latest_version = available_versions.split(",")[0].strip()
+            return latest_version
+
+    return default_version
+
+
+def encode_url(package_url: str) -> str:
+    parsed_url = urlparse(package_url)
+    encoded_path = quote(parsed_url.path)
+    safe_url = parsed_url._replace(path=encoded_path).geturl()
+    return safe_url, parsed_url.path
+
+
+def cache_package(package_url: str, package_name: str, is_windows: bool = False):
+    safe_url, parsed_url = encode_url(package_url)
+    if BUILD_NO_CACHE:
+        return safe_url
+    filename = os.path.basename(parsed_url)
+    cache_dir = os.path.join(user_cache_dir("pip"), "http", "wheels", package_name)
+    os.makedirs(cache_dir, exist_ok=True)
+
+    local_path = os.path.join(cache_dir, filename)
+    if not os.path.exists(local_path):
+        # temp_file, temp_path = tempfile.mkstemp()
+        temp_path = local_path + ".tmp"
+        if os.path.exists(temp_path):
+            os.remove(temp_path)
+        try:
+            print(f"Download {safe_url} to {local_path}")
+            urllib.request.urlretrieve(safe_url, temp_path)
+            shutil.move(temp_path, local_path)
+        finally:
+            if os.path.exists(temp_path):
+                os.remove(temp_path)
+    return f"file:///{local_path}" if is_windows else f"file://{local_path}"
+
+
+class SetupSpec:
+    def __init__(self) -> None:
+        self.extras: dict = {}
+        self.install_requires: List[str] = []
+
+
+setup_spec = SetupSpec()
```
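A note and a usage sketch for these helpers. `encode_url` is annotated `-> str` but actually returns a `(safe_url, path)` tuple, which both call sites unpack; `Tuple[str, str]` would be the accurate annotation. Its job is to percent-encode the URL path (notably the `+` in local version tags like `2.0.0+cu118`, which becomes `%2B`) so `urlretrieve` receives a well-formed URL. `cache_package` then downloads a wheel once into pip's cache directory and hands back a `file://` URL usable as a direct reference. The wheel URL and cache path below are illustrative only; this sketch assumes it runs in setup.py's context where the helpers are defined:

```python
# Illustrative wheel URL; the file actually cached depends on the platform
# detection performed later in setup.py.
wheel_url = (
    "https://download.pytorch.org/whl/cu118/"
    "torch-2.0.0+cu118-cp310-cp310-linux_x86_64.whl"
)

# First call downloads into pip's cache (Linux path shown); subsequent builds
# reuse the cached file and skip the network entirely.
local_url = cache_package(wheel_url, "torch", is_windows=False)
# e.g. "file:///home/me/.cache/pip/http/wheels/torch/torch-2.0.0%2Bcu118-...whl"

# Ask a custom index for the newest build via `pip index versions`, falling
# back to a pinned default when the command fails or prints nothing parseable.
latest = get_latest_version(
    "bitsandbytes",
    "https://jllllll.github.io/bitsandbytes-windows-webui",
    "0.41.1",
)
```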
```diff
@@ -56,22 +124,27 @@ def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
     cpu_avx = AVXType.BASIC
     env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX"))

-    cmds = ["lscpu"]
-    if system == "Windows":
-        cmds = ["coreinfo"]
+    if "windows" in system.lower():
+        os_type = OSType.WINDOWS
+        output = "avx2"
+        print("Current platform is windows, use avx2 as default cpu architecture")
     elif system == "Linux":
-        cmds = ["lscpu"]
         os_type = OSType.LINUX
+        result = subprocess.run(
+            ["lscpu"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode()
     elif system == "Darwin":
-        cmds = ["sysctl", "-a"]
         os_type = OSType.DARWIN
+        result = subprocess.run(
+            ["sysctl", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        output = result.stdout.decode()
     else:
         os_type = OSType.OTHER
         print("Unsupported OS to get cpu avx, use default")
         return os_type, env_cpu_avx if env_cpu_avx else cpu_avx
-    result = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    output = result.stdout.decode()

     if "avx512" in output.lower():
         cpu_avx = AVXType.AVX512
     elif "avx2" in output.lower():
```
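The rewritten branch no longer shells out to `coreinfo` on Windows (a Sysinternals tool most machines lack, and a likely source of the install error this commit fixes); it simply assumes AVX2. The `DBGPT_LLAMA_CPP_AVX` variable remains the escape hatch. A sketch of the override, assuming `AVXType.of_type` (defined earlier in setup.py, outside this diff) accepts the spelling shown:

```python
# Hypothetical override: force AVX512 regardless of platform probing. The
# exact accepted values depend on AVXType.of_type(), which this diff omits.
os.environ["DBGPT_LLAMA_CPP_AVX"] = "AVX512"

os_type, cpu_avx = get_cpu_avx_support()
# On a Windows host this would return (OSType.WINDOWS, AVXType.AVX512):
# the env value wins over the hard-coded "avx2" default.
```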
```diff
@@ -82,15 +155,97 @@ def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
     return os_type, env_cpu_avx if env_cpu_avx else cpu_avx


-def get_cuda_version() -> str:
+def get_cuda_version_from_torch():
     try:
         import torch

         return torch.version.cuda
     except:
         return None


+def get_cuda_version_from_nvcc():
+    try:
+        output = subprocess.check_output(["nvcc", "--version"])
+        version_line = [
+            line for line in output.decode("utf-8").split("\n") if "release" in line
+        ][0]
+        return version_line.split("release")[-1].strip().split(",")[0]
+    except:
+        return None
+
+
+def get_cuda_version_from_nvidia_smi():
+    try:
+        output = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
+        match = re.search(r"CUDA Version:\s+(\d+\.\d+)", output)
+        if match:
+            return match.group(1)
+        else:
+            return None
+    except:
+        return None
+
+
+def get_cuda_version() -> str:
+    try:
+        cuda_version = get_cuda_version_from_torch()
+        if not cuda_version:
+            cuda_version = get_cuda_version_from_nvcc()
+        if not cuda_version:
+            cuda_version = get_cuda_version_from_nvidia_smi()
+        return cuda_version
+    except Exception:
+        return None
```
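Detection now degrades through three probes: a CUDA-enabled torch import, then `nvcc --version`, then the `nvidia-smi` header; each returns `None` on failure. (The bare `except:` clauses also swallow `KeyboardInterrupt`; `except Exception:` would be the more conventional spelling.) A sketch of the same chain written explicitly; the hunk then continues below with the new `torch_requires` helper:

```python
# Equivalent, more explicit form of get_cuda_version() above: try each probe
# in order and stop at the first one that yields a version string.
def detect_cuda_version():
    for probe in (
        get_cuda_version_from_torch,       # torch.version.cuda, e.g. "11.8"
        get_cuda_version_from_nvcc,        # parses "release 11.8, V11.8.89"
        get_cuda_version_from_nvidia_smi,  # parses "CUDA Version: 11.8"
    ):
        version = probe()
        if version:
            return version
    return None  # no CUDA found: torch_requires() falls back to CPU wheels
```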
```diff
+def torch_requires(
+    torch_version: str = "2.0.0",
+    torchvision_version: str = "0.15.1",
+    torchaudio_version: str = "2.0.1",
+):
+    torch_pkgs = []
+    os_type, _ = get_cpu_avx_support()
+    if os_type == OSType.DARWIN:
+        torch_pkgs = [
+            f"torch=={torch_version}",
+            f"torchvision=={torchvision_version}",
+            f"torchaudio=={torchaudio_version}",
+        ]
+    else:
+        cuda_version = get_cuda_version()
+        if not cuda_version:
+            torch_pkgs = [
+                f"torch=={torch_version}+cpu",
+                f"torchvision=={torchvision_version}+cpu",
+                f"torchaudio=={torchaudio_version}",
+            ]
+        else:
+            supported_versions = ["11.7", "11.8"]
+            if cuda_version not in supported_versions:
+                print(
+                    f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}"
+                )
+                cuda_version = supported_versions[-1]
+            cuda_version = "cu" + cuda_version.replace(".", "")
+            py_version = "cp310"
+            os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
+            torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+            torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+            torch_url_cached = cache_package(
+                torch_url, "torch", os_type == OSType.WINDOWS
+            )
+            torchvision_url_cached = cache_package(
+                torchvision_url, "torchvision", os_type == OSType.WINDOWS
+            )
+            torch_pkgs = [
+                f"torch @ {torch_url_cached}",
+                f"torchvision @ {torchvision_url_cached}",
+                f"torchaudio=={torchaudio_version}",
+            ]
+    setup_spec.extras["torch"] = torch_pkgs


 def llama_cpp_python_cuda_requires():
     cuda_version = get_cuda_version()
     device = "cpu"
```
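The core of `torch_requires` is URL assembly: CUDA versions outside {11.7, 11.8} are coerced to 11.8, the dotted version becomes a `cuXYZ` tag, and the wheel is pinned to CPython 3.10 (`cp310`), so the resulting pins only install cleanly on Python 3.10 even though `setup()` declares `python_requires=">=3.10"`. Expanded by hand for a CUDA 11.8 Linux host (values illustrative):

```python
# Hand-expanded example of the URL torch_requires() builds on Linux + CUDA 11.8:
cuda_version = "cu" + "11.8".replace(".", "")  # "cu118"
py_version = "cp310"
os_pkg_name = "linux_x86_64"
torch_url = (
    f"https://download.pytorch.org/whl/{cuda_version}/"
    f"torch-2.0.0+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
)
# -> https://download.pytorch.org/whl/cu118/torch-2.0.0+cu118-cp310-cp310-linux_x86_64.whl
```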
```diff
@@ -105,12 +260,15 @@ def llama_cpp_python_cuda_requires():
             f"llama_cpp_python_cuda just support in os: {[r._value_ for r in supported_os]}"
         )
         return
     if cpu_avx == AVXType.AVX2 or AVXType.AVX512:
         cpu_avx = AVXType.AVX
     cpu_avx = cpu_avx._value_
     base_url = "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui"
     llama_cpp_version = "0.1.77"
     py_version = "cp310"
+    os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
+    extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
+    extra_index_url, _ = encode_url(extra_index_url)
     print(f"Install llama_cpp_python_cuda from {extra_index_url}")

     setup_spec.extras["llama_cpp"].append(f"llama_cpp_python_cuda @ {extra_index_url}")
```
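One reviewer-style note on the condition above: `cpu_avx == AVXType.AVX2 or AVXType.AVX512` parses as `(cpu_avx == AVXType.AVX2) or AVXType.AVX512`, and an enum member is always truthy, so the branch runs unconditionally and every detected level, including BASIC, is collapsed to plain AVX. If only the higher levels were meant to be mapped down, a membership test is the usual fix; a sketch of the presumably intended check:

```python
# Presumably intended: map AVX2/AVX512 down to the generic "avx" wheel tag
# without also silently upgrading BASIC builds to AVX.
if cpu_avx in (AVXType.AVX2, AVXType.AVX512):
    cpu_avx = AVXType.AVX
```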
```diff
@@ -124,6 +282,26 @@ def llama_cpp_requires():
         llama_cpp_python_cuda_requires()


+def quantization_requires():
+    pkgs = []
+    os_type, _ = get_cpu_avx_support()
+    if os_type != OSType.WINDOWS:
+        pkgs = ["bitsandbytes"]
+    else:
+        latest_version = get_latest_version(
+            "bitsandbytes",
+            "https://jllllll.github.io/bitsandbytes-windows-webui",
+            "0.41.1",
+        )
+        extra_index_url = f"https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-{latest_version}-py3-none-win_amd64.whl"
+        local_pkg = cache_package(
+            extra_index_url, "bitsandbytes", os_type == OSType.WINDOWS
+        )
+        pkgs = [f"bitsandbytes @ {local_pkg}"]
+    print(pkgs)
+    setup_spec.extras["quantization"] = pkgs
+
+
 def all_vector_store_requires():
     """
     pip install "db-gpt[vstore]"
```
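`bitsandbytes` shipped no official Windows wheels at the time, so on Windows the setup resolves a prebuilt `win_amd64` wheel from jllllll's index, caches it, and pins it with a PEP 508 direct reference (`name @ url`). Roughly what lands in the extra, with the cache path hypothetical and the version depending on what the index reports:

```python
# Linux/macOS: plain index resolution.
setup_spec.extras["quantization"] = ["bitsandbytes"]

# Windows (hypothetical cache path; version from get_latest_version() or the
# "0.41.1" fallback): a PEP 508 direct reference to the locally cached wheel.
setup_spec.extras["quantization"] = [
    "bitsandbytes @ file:///C:/Users/me/AppData/Local/pip/cache/http/wheels/"
    "bitsandbytes/bitsandbytes-0.41.1-py3-none-win_amd64.whl"
]
```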
```diff
@@ -149,12 +327,22 @@ def all_requires():
     setup_spec.extras["all"] = list(requires)


+def init_install_requires():
+    setup_spec.install_requires += parse_requirements("requirements.txt")
+    setup_spec.install_requires += setup_spec.extras["torch"]
+    setup_spec.install_requires += setup_spec.extras["quantization"]
+    print(f"Install requires: \n{','.join(setup_spec.install_requires)}")
+
+
+torch_requires()
 llama_cpp_requires()
+quantization_requires()
 all_vector_store_requires()
 all_datasource_requires()

 # must be last
 all_requires()
+init_install_requires()

 setuptools.setup(
     name="db-gpt",
```
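`init_install_requires` folds the resolved `torch` and `quantization` extras into the base dependency list, so a plain `pip install .` pulls the platform-specific wheels without the user naming any extra; the `# must be last` comment applies because `all_requires()` aggregates every extra registered above it. A sketch of the resulting list on a CPU-only Linux host (entries illustrative):

```python
# Hypothetical contents after init_install_requires() on CPU-only Linux:
# the requirements.txt entries, then the torch extra, then quantization.
expected = parse_requirements("requirements.txt") + [
    "torch==2.0.0+cpu",
    "torchvision==0.15.1+cpu",
    "torchaudio==2.0.1",
    "bitsandbytes",
]
assert setup_spec.install_requires == expected
```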
```diff
@@ -166,7 +354,7 @@ setuptools.setup(
     " With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    install_requires=parse_requirements("requirements.txt"),
+    install_requires=setup_spec.install_requires,
     url="https://github.com/eosphoros-ai/DB-GPT",
     license="https://opensource.org/license/mit/",
     python_requires=">=3.10",
```