Mirror of https://github.com/csunny/DB-GPT.git
feat: Upgrade torch version to 2.2.1 (#1374)

Co-authored-by: yyhhyy <yyhhyyyyyy@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>

parent bb77e13dee
commit 5a0c20a864
Dockerfile

@@ -1,4 +1,4 @@
-ARG BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
+ARG BASE_IMAGE="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
 
 FROM ${BASE_IMAGE}
 ARG BASE_IMAGE
Docker build script

@@ -4,7 +4,7 @@ SCRIPT_LOCATION=$0
 cd "$(dirname "$SCRIPT_LOCATION")"
 WORK_DIR=$(pwd)
 
-BASE_IMAGE_DEFAULT="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
+BASE_IMAGE_DEFAULT="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
 BASE_IMAGE_DEFAULT_CPU="ubuntu:22.04"
 
 BASE_IMAGE=$BASE_IMAGE_DEFAULT
@@ -21,7 +21,7 @@ BUILD_NETWORK=""
 DB_GPT_INSTALL_MODEL="default"
 
 usage () {
-    echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-runtime-ubuntu22.04] [--image-name db-gpt]"
+    echo "USAGE: $0 [--base-image nvidia/cuda:12.1.0-runtime-ubuntu22.04] [--image-name db-gpt]"
     echo "  [-b|--base-image base image name] Base image name"
     echo "  [-n|--image-name image name] Current image name, default: db-gpt"
     echo "  [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
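With both defaults moved from CUDA 11.8 to 12.1, images for hosts that must stay on the 11.8 driver stack can still be built by passing the old base image through the script's -b|--base-image option shown in the usage text above.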
setup.py (155 changed lines)
@@ -4,6 +4,7 @@ import platform
 import re
 import shutil
 import subprocess
 import sys
+import urllib.request
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -40,15 +41,22 @@ def parse_requirements(file_name: str) -> List[str]:
 ]
 
 
-def get_latest_version(package_name: str, index_url: str, default_version: str):
-    python_command = shutil.which("python")
-    if not python_command:
-        python_command = shutil.which("python3")
-        if not python_command:
-            print("Python command not found.")
-            return default_version
+def find_python():
+    python_path = sys.executable
+    print(python_path)
+    if not python_path:
+        print("Python command not found.")
+        return None
+    return python_path
+
+
+def get_latest_version(package_name: str, index_url: str, default_version: str):
+    python_command = find_python()
+    if not python_command:
+        print("Python command not found.")
+        return default_version
 
-    command = [
+    command_index_versions = [
         python_command,
         "-m",
         "pip",
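A side note on the new find_python(): sys.executable is normally an absolute path for a regular CPython process, but the Python docs allow it to be None or an empty string when the interpreter cannot determine its own path (for example in some embedded setups), which is what the `if not python_path` guard covers. A minimal sketch of the same resolution with the removed shutil.which lookups kept as a fallback (an illustration, not part of this commit):

import shutil
import sys


def find_python_sketch():
    # Prefer the running interpreter; fall back to PATH lookups the
    # way the removed code did.
    return sys.executable or shutil.which("python") or shutil.which("python3")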
@@ -59,20 +67,40 @@ def get_latest_version(package_name: str, index_url: str, default_version: str):
         index_url,
     ]
 
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    if result.returncode != 0:
-        print("Error executing command.")
-        print(result.stderr.decode())
-        return default_version
-
-    output = result.stdout.decode()
-    lines = output.split("\n")
-    for line in lines:
-        if "Available versions:" in line:
-            available_versions = line.split(":")[1].strip()
-            latest_version = available_versions.split(",")[0].strip()
-            return latest_version
+    result_index_versions = subprocess.run(
+        command_index_versions, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
+    if result_index_versions.returncode == 0:
+        output = result_index_versions.stdout.decode()
+        lines = output.split("\n")
+        for line in lines:
+            if "Available versions:" in line:
+                available_versions = line.split(":")[1].strip()
+                latest_version = available_versions.split(",")[0].strip()
+                # Query for compatibility with the latest version of torch
+                if package_name == "torch" or "torchvision":
+                    latest_version = latest_version.split("+")[0]
+                return latest_version
+    else:
+        command_simulate_install = [
+            python_command,
+            "-m",
+            "pip",
+            "install",
+            f"{package_name}==",
+        ]
+        result_simulate_install = subprocess.run(
+            command_simulate_install, stderr=subprocess.PIPE
+        )
+        print(result_simulate_install)
+        stderr_output = result_simulate_install.stderr.decode()
+        print(stderr_output)
+        match = re.search(r"from versions: (.+?)\)", stderr_output)
+        if match:
+            available_versions = match.group(1).split(", ")
+            latest_version = available_versions[-1].strip()
+            return latest_version
 
     return default_version
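Two details in this hunk are worth flagging. First, `pip index versions` is still marked experimental by pip, hence the fallback branch that scrapes the "from versions: ..." list out of the error output of a deliberately unsatisfiable `pip install {package_name}==`. Second, `if package_name == "torch" or "torchvision":` is always true, because the non-empty string literal on the right is truthy on its own, so the local-version suffix is stripped for every package, not just the torch family. A sketch of the pitfall and the likely intent (illustration only, not part of the commit):

package_name = "numpy"
latest_version = "1.26.4"

# As committed: the right-hand operand is a non-empty string, so the
# condition is True for every package_name.
if package_name == "torch" or "torchvision":
    latest_version = latest_version.split("+")[0]

# Likely intent: only torch-family versions carry tags like "2.2.1+cu121".
if package_name in ("torch", "torchvision"):
    latest_version = latest_version.split("+")[0]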
@@ -227,7 +255,7 @@ def _build_wheels(
     base_url: str = None,
     base_url_func: Callable[[str, str, str], str] = None,
     pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
-    supported_cuda_versions: List[str] = ["11.7", "11.8"],
+    supported_cuda_versions: List[str] = ["11.8", "12.1"],
 ) -> Optional[str]:
     """
     Build the URL for the package wheel file based on the package name, version, and CUDA version.
@@ -248,11 +276,17 @@ def _build_wheels(
     py_version = "cp" + "".join(py_version.split(".")[0:2])
     if os_type == OSType.DARWIN or not cuda_version:
         return None
-    if cuda_version not in supported_cuda_versions:
+
+    if cuda_version in supported_cuda_versions:
+        cuda_version = cuda_version
+    else:
         print(
-            f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
+            f"Warning: Your CUDA version {cuda_version} is not in our set supported_cuda_versions , we will use our set version."
         )
-        cuda_version = supported_cuda_versions[-1]
+        if cuda_version < "12.1":
+            cuda_version = supported_cuda_versions[0]
+        else:
+            cuda_version = supported_cuda_versions[-1]
 
     cuda_version = "cu" + cuda_version.replace(".", "")
     os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
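Note that the fallback compares version strings lexicographically: for the CUDA versions DB-GPT actually encounters ("11.x", "12.x") this works, but a hypothetical "9.2" would compare greater than "12.1" and be routed to the 12.1 wheel. A numeric tuple comparison avoids the surprise (a sketch, not part of the commit):

def parse_cuda_version(v: str) -> tuple:
    # "11.8" -> (11, 8), so comparisons are numeric, not lexicographic.
    return tuple(int(part) for part in v.split("."))


assert not ("9.2" < "12.1")  # string compare: "9" > "1", so this is False
assert parse_cuda_version("9.2") < parse_cuda_version("12.1")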
@@ -273,48 +307,49 @@ def _build_wheels(
 
 
 def torch_requires(
-    torch_version: str = "2.0.1",
-    torchvision_version: str = "0.15.2",
-    torchaudio_version: str = "2.0.2",
+    torch_version: str = "2.2.1",
+    torchvision_version: str = "0.17.1",
+    torchaudio_version: str = "2.2.1",
 ):
-    os_type, _ = get_cpu_avx_support()
     torch_pkgs = [
         f"torch=={torch_version}",
         f"torchvision=={torchvision_version}",
         f"torchaudio=={torchaudio_version}",
     ]
-    torch_cuda_pkgs = []
+    os_type, _ = get_cpu_avx_support()
+    # Initialize torch_cuda_pkgs for non-Darwin OSes;
+    # it will be the same as torch_pkgs for Darwin or when no specific CUDA handling is needed
+    torch_cuda_pkgs = torch_pkgs[:]
+
     if os_type != OSType.DARWIN:
         cuda_version = get_cuda_version()
         if cuda_version:
-            supported_versions = ["11.7", "11.8"]
-            # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            torch_url = _build_wheels(
-                "torch",
-                torch_version,
-                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
-                supported_cuda_versions=supported_versions,
-            )
-            torchvision_url = _build_wheels(
-                "torchvision",
-                torchvision_version,
-                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
-                supported_cuda_versions=supported_versions,
-            )
-            torch_url_cached = cache_package(
-                torch_url, "torch", os_type == OSType.WINDOWS
-            )
-            torchvision_url_cached = cache_package(
-                torchvision_url, "torchvision", os_type == OSType.WINDOWS
-            )
-            torch_cuda_pkgs = [
-                f"torch @ {torch_url_cached}",
-                f"torchvision @ {torchvision_url_cached}",
-                f"torchaudio=={torchaudio_version}",
-            ]
+            supported_versions = ["11.8", "12.1"]
+            base_url_func = lambda v, x, y: f"https://download.pytorch.org/whl/{x}"
+            torch_url = _build_wheels(
+                "torch",
+                torch_version,
+                base_url_func=base_url_func,
+                supported_cuda_versions=supported_versions,
+            )
+            torchvision_url = _build_wheels(
+                "torchvision",
+                torchvision_version,
+                base_url_func=base_url_func,
+                supported_cuda_versions=supported_versions,
+            )
+
+            # Cache and add CUDA-dependent packages if URLs are available
+            if torch_url:
+                torch_url_cached = cache_package(
+                    torch_url, "torch", os_type == OSType.WINDOWS
+                )
+                torch_cuda_pkgs[0] = f"torch @ {torch_url_cached}"
+            if torchvision_url:
+                torchvision_url_cached = cache_package(
+                    torchvision_url, "torchvision", os_type == OSType.WINDOWS
+                )
+                torch_cuda_pkgs[1] = f"torchvision @ {torchvision_url_cached}"
 
+    # Assuming 'setup_spec' is a dictionary where we're adding these dependencies
     setup_spec.extras["torch"] = torch_pkgs
     setup_spec.extras["torch_cpu"] = torch_pkgs
     setup_spec.extras["torch_cuda"] = torch_cuda_pkgs
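Beyond the version bumps, this hunk changes the failure mode: torch_cuda_pkgs now starts as a copy of torch_pkgs, and individual entries are only swapped for cached wheel URLs when _build_wheels resolves one, whereas the removed code referenced torch_url_cached unconditionally and would fail when no URL was built. For the new defaults, the wheel URL follows the shape of the template visible in the removed comment lines; for example (the exact file's availability on download.pytorch.org is an assumption):

torch_version, cuda, py, plat = "2.2.1", "cu121", "cp310", "linux_x86_64"
torch_url = (
    f"https://download.pytorch.org/whl/{cuda}/"
    f"torch-{torch_version}+{cuda}-{py}-{py}-{plat}.whl"
)
# -> https://download.pytorch.org/whl/cu121/torch-2.2.1+cu121-cp310-cp310-linux_x86_64.whl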
@@ -322,6 +357,7 @@ def torch_requires(
 
 
 def llama_cpp_python_cuda_requires():
     cuda_version = get_cuda_version()
+    supported_cuda_versions = ["11.8", "12.1"]
     device = "cpu"
     if not cuda_version:
         print("CUDA not support, use cpu version")
@@ -330,7 +366,10 @@ def llama_cpp_python_cuda_requires():
         print("Disable GPU acceleration")
         return
     # Supports GPU acceleration
-    device = "cu" + cuda_version.replace(".", "")
+    if cuda_version <= "11.8" and not None:
+        device = "cu" + supported_cuda_versions[0].replace(".", "")
+    else:
+        device = "cu" + supported_cuda_versions[-1].replace(".", "")
     os_type, cpu_avx = get_cpu_avx_support()
     print(f"OS: {os_type}, cpu avx: {cpu_avx}")
     supported_os = [OSType.WINDOWS, OSType.LINUX]
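One more flag: `not None` always evaluates to True, so the new condition reduces to the bare `cuda_version <= "11.8"` (again a lexicographic string comparison); the `and not None` clause looks like a leftover and has no effect. A quick check (illustration only):

for cuda_version in ("11.8", "12.1"):
    # `not None` is True, so the committed test equals the bare comparison.
    assert (cuda_version <= "11.8" and not None) == (cuda_version <= "11.8")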
@@ -346,7 +385,7 @@ def llama_cpp_python_cuda_requires():
         cpu_device = "basic"
     device += cpu_device
     base_url = "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui"
-    llama_cpp_version = "0.2.10"
+    llama_cpp_version = "0.2.26"
     py_version = "cp310"
     os_pkg_name = "manylinux_2_31_x86_64" if os_type == OSType.LINUX else "win_amd64"
     extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
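For reference, with the bumped version a CUDA 12.1, AVX2, Linux build would request a wheel shaped like the following from the jllllll cuBLAS wheel releases (the "avx2" device suffix depends on the cpu_avx detection above, and the availability of this exact asset is an assumption):

base_url = (
    "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels"
    "/releases/download/textgen-webui"
)
version, device, py = "0.2.26", "cu121avx2", "cp310"
os_pkg_name = "manylinux_2_31_x86_64"
url = f"{base_url}/llama_cpp_python_cuda-{version}+{device}-{py}-{py}-{os_pkg_name}.whl"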
@@ -493,7 +532,13 @@ def quantization_requires():
         # autoawq requirements:
         # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
         # 2. CUDA Toolkit 11.8 and later.
-        quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
+        cuda_version = get_cuda_version()
+        autoawq_latest_version = get_latest_version("autoawq", "", "0.2.4")
+        if cuda_version is None or cuda_version == "12.1":
+            quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
+        else:
+            # TODO(yyhhyy): Add autoawq install method for CUDA version 11.8
+            quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
 
     setup_spec.extras["quantization"] = ["cpm_kernels"] + quantization_pkgs
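As committed, both branches of the new conditional still extend the same package list, and autoawq_latest_version is resolved but never used; the TODO marks the spot where a CUDA-11.8-specific autoawq pin would eventually diverge from the CUDA 12.1 path.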