mirror of https://github.com/csunny/DB-GPT.git

feat: Upgrade torch version to 2.2.1 (#1374)

Co-authored-by: yyhhyy <yyhhyyyyyy@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>

commit 5a0c20a864 (parent bb77e13dee)
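Summary of the change set below: the commit moves the default CUDA stack from 11.8 to 12.1 in the Docker base image and build script, bumps torch to 2.2.1 (torchvision 0.17.1, torchaudio 2.2.1), upgrades the prebuilt llama-cpp-python CUDA wheel to 0.2.26, adds a find_python() helper plus a pip-based fallback for version discovery in get_latest_version(), and begins distinguishing CUDA 11.8 from 12.1 when selecting wheels and autoawq packaging.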
Dockerfile:

@@ -1,4 +1,4 @@
-ARG BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
+ARG BASE_IMAGE="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
 
 FROM ${BASE_IMAGE}
 ARG BASE_IMAGE
Docker build script:

@@ -4,7 +4,7 @@ SCRIPT_LOCATION=$0
 cd "$(dirname "$SCRIPT_LOCATION")"
 WORK_DIR=$(pwd)
 
-BASE_IMAGE_DEFAULT="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
+BASE_IMAGE_DEFAULT="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
 BASE_IMAGE_DEFAULT_CPU="ubuntu:22.04"
 
 BASE_IMAGE=$BASE_IMAGE_DEFAULT

@@ -21,7 +21,7 @@ BUILD_NETWORK=""
 DB_GPT_INSTALL_MODEL="default"
 
 usage () {
-    echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-runtime-ubuntu22.04] [--image-name db-gpt]"
+    echo "USAGE: $0 [--base-image nvidia/cuda:12.1.0-runtime-ubuntu22.04] [--image-name db-gpt]"
     echo "  [-b|--base-image base image name] Base image name"
     echo "  [-n|--image-name image name] Current image name, default: db-gpt"
     echo "  [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
setup.py (113 lines changed):
@@ -4,6 +4,7 @@ import platform
 import re
 import shutil
 import subprocess
+import sys
 import urllib.request
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -40,15 +41,22 @@ def parse_requirements(file_name: str) -> List[str]:
 ]
 
 
+def find_python():
+    python_path = sys.executable
+    print(python_path)
+    if not python_path:
+        print("Python command not found.")
+        return None
+    return python_path
+
+
 def get_latest_version(package_name: str, index_url: str, default_version: str):
-    python_command = shutil.which("python")
-    if not python_command:
-        python_command = shutil.which("python3")
+    python_command = find_python()
     if not python_command:
         print("Python command not found.")
         return default_version
 
-    command = [
+    command_index_versions = [
         python_command,
         "-m",
         "pip",
@@ -59,20 +67,40 @@ def get_latest_version(package_name: str, index_url: str, default_version: str):
         index_url,
     ]
 
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    if result.returncode != 0:
-        print("Error executing command.")
-        print(result.stderr.decode())
-        return default_version
-
-    output = result.stdout.decode()
-    lines = output.split("\n")
-    for line in lines:
-        if "Available versions:" in line:
-            available_versions = line.split(":")[1].strip()
-            latest_version = available_versions.split(",")[0].strip()
-            return latest_version
+    result_index_versions = subprocess.run(
+        command_index_versions, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
+    if result_index_versions.returncode == 0:
+        output = result_index_versions.stdout.decode()
+        lines = output.split("\n")
+        for line in lines:
+            if "Available versions:" in line:
+                available_versions = line.split(":")[1].strip()
+                latest_version = available_versions.split(",")[0].strip()
+                # Query for compatibility with the latest version of torch
+                if package_name == "torch" or "torchvision":
+                    latest_version = latest_version.split("+")[0]
+                return latest_version
+    else:
+        command_simulate_install = [
+            python_command,
+            "-m",
+            "pip",
+            "install",
+            f"{package_name}==",
+        ]
+
+        result_simulate_install = subprocess.run(
+            command_simulate_install, stderr=subprocess.PIPE
+        )
+        print(result_simulate_install)
+        stderr_output = result_simulate_install.stderr.decode()
+        print(stderr_output)
+        match = re.search(r"from versions: (.+?)\)", stderr_output)
+        if match:
+            available_versions = match.group(1).split(", ")
+            latest_version = available_versions[-1].strip()
+            return latest_version
 
     return default_version
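One caveat in this hunk: the added condition `package_name == "torch" or "torchvision"` is always true in Python, because the non-empty string literal "torchvision" is truthy on its own, so the local-version suffix gets stripped for every package that reaches this branch, not just torch and torchvision. A minimal sketch of the presumably intended membership test (not what the commit ships):

    # The original `package_name == "torch" or "torchvision"` can never be False;
    # a membership test expresses the likely intent:
    if package_name in ("torch", "torchvision"):
        # Drop the local-version tag, e.g. "2.2.1+cu121" -> "2.2.1"
        latest_version = latest_version.split("+")[0]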
@@ -227,7 +255,7 @@ def _build_wheels(
     base_url: str = None,
     base_url_func: Callable[[str, str, str], str] = None,
     pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
-    supported_cuda_versions: List[str] = ["11.7", "11.8"],
+    supported_cuda_versions: List[str] = ["11.8", "12.1"],
 ) -> Optional[str]:
     """
     Build the URL for the package wheel file based on the package name, version, and CUDA version.
@@ -248,10 +276,16 @@ def _build_wheels(
     py_version = "cp" + "".join(py_version.split(".")[0:2])
     if os_type == OSType.DARWIN or not cuda_version:
         return None
-    if cuda_version not in supported_cuda_versions:
+
+    if cuda_version in supported_cuda_versions:
+        cuda_version = cuda_version
+    else:
         print(
-            f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
+            f"Warning: Your CUDA version {cuda_version} is not in our set supported_cuda_versions , we will use our set version."
         )
-        cuda_version = supported_cuda_versions[-1]
+        if cuda_version < "12.1":
+            cuda_version = supported_cuda_versions[0]
+        else:
+            cuda_version = supported_cuda_versions[-1]
 
     cuda_version = "cu" + cuda_version.replace(".", "")
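Note that `cuda_version < "12.1"` compares strings lexicographically. That happens to route "11.x" versions correctly, but a hypothetical "9.0" would compare greater than "12.1" (since the character "9" sorts after "1") and be mapped to the 12.1 wheel. A sketch of a numeric comparison, assuming dotted-integer version strings:

    # Sketch: compare versions numerically instead of lexicographically.
    def _version_tuple(version: str):
        return tuple(int(part) for part in version.split("."))

    if _version_tuple(cuda_version) < _version_tuple("12.1"):
        cuda_version = supported_cuda_versions[0]   # "11.8"
    else:
        cuda_version = supported_cuda_versions[-1]  # "12.1"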
@@ -273,48 +307,49 @@ def _build_wheels(
 
 
 def torch_requires(
-    torch_version: str = "2.0.1",
-    torchvision_version: str = "0.15.2",
-    torchaudio_version: str = "2.0.2",
+    torch_version: str = "2.2.1",
+    torchvision_version: str = "0.17.1",
+    torchaudio_version: str = "2.2.1",
 ):
+    os_type, _ = get_cpu_avx_support()
     torch_pkgs = [
         f"torch=={torch_version}",
         f"torchvision=={torchvision_version}",
         f"torchaudio=={torchaudio_version}",
     ]
-    torch_cuda_pkgs = []
-    os_type, _ = get_cpu_avx_support()
+    # Initialize torch_cuda_pkgs for non-Darwin OSes;
+    # it will be the same as torch_pkgs for Darwin or when no specific CUDA handling is needed
+    torch_cuda_pkgs = torch_pkgs[:]
+
     if os_type != OSType.DARWIN:
-        cuda_version = get_cuda_version()
-        if cuda_version:
-            supported_versions = ["11.7", "11.8"]
-            # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            torch_url = _build_wheels(
-                "torch",
-                torch_version,
-                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
-                supported_cuda_versions=supported_versions,
-            )
-            torchvision_url = _build_wheels(
-                "torchvision",
-                torchvision_version,
-                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
-                supported_cuda_versions=supported_versions,
-            )
-            torch_url_cached = cache_package(
-                torch_url, "torch", os_type == OSType.WINDOWS
-            )
-            torchvision_url_cached = cache_package(
-                torchvision_url, "torchvision", os_type == OSType.WINDOWS
-            )
-            torch_cuda_pkgs = [
-                f"torch @ {torch_url_cached}",
-                f"torchvision @ {torchvision_url_cached}",
-                f"torchaudio=={torchaudio_version}",
-            ]
+        supported_versions = ["11.8", "12.1"]
+        base_url_func = lambda v, x, y: f"https://download.pytorch.org/whl/{x}"
+        torch_url = _build_wheels(
+            "torch",
+            torch_version,
+            base_url_func=base_url_func,
+            supported_cuda_versions=supported_versions,
+        )
+        torchvision_url = _build_wheels(
+            "torchvision",
+            torchvision_version,
+            base_url_func=base_url_func,
+            supported_cuda_versions=supported_versions,
+        )
+
+        # Cache and add CUDA-dependent packages if URLs are available
+        if torch_url:
+            torch_url_cached = cache_package(
+                torch_url, "torch", os_type == OSType.WINDOWS
+            )
+            torch_cuda_pkgs[0] = f"torch @ {torch_url_cached}"
+        if torchvision_url:
+            torchvision_url_cached = cache_package(
+                torchvision_url, "torchvision", os_type == OSType.WINDOWS
+            )
+            torch_cuda_pkgs[1] = f"torchvision @ {torchvision_url_cached}"
+
+        # Assuming 'setup_spec' is a dictionary where we're adding these dependencies
 
     setup_spec.extras["torch"] = torch_pkgs
     setup_spec.extras["torch_cpu"] = torch_pkgs
     setup_spec.extras["torch_cuda"] = torch_cuda_pkgs
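For illustration, the base_url_func lambda combines with PyTorch's wheel naming to produce URLs of the shape below for CUDA 12.1, CPython 3.10 on Linux. The file name itself is assembled elsewhere in _build_wheels/pkg_file_func (not shown in this hunk), so treat this as a sketch; the commented-out URLs removed by this hunk show the same pattern:

    # Sketch of the wheel URL shape for torch 2.2.1 / cu121 / cp310 / Linux:
    cuda_tag = "cu121"
    base_url = f"https://download.pytorch.org/whl/{cuda_tag}"
    torch_wheel = f"{base_url}/torch-2.2.1+{cuda_tag}-cp310-cp310-linux_x86_64.whl"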
@@ -322,6 +357,7 @@ def torch_requires(
 
 
 def llama_cpp_python_cuda_requires():
     cuda_version = get_cuda_version()
+    supported_cuda_versions = ["11.8", "12.1"]
     device = "cpu"
     if not cuda_version:
         print("CUDA not support, use cpu version")
@@ -330,7 +366,10 @@ def llama_cpp_python_cuda_requires():
         print("Disable GPU acceleration")
         return
     # Supports GPU acceleration
-    device = "cu" + cuda_version.replace(".", "")
+    if cuda_version <= "11.8" and not None:
+        device = "cu" + supported_cuda_versions[0].replace(".", "")
+    else:
+        device = "cu" + supported_cuda_versions[-1].replace(".", "")
     os_type, cpu_avx = get_cpu_avx_support()
     print(f"OS: {os_type}, cpu avx: {cpu_avx}")
     supported_os = [OSType.WINDOWS, OSType.LINUX]
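Two quirks in the new condition are worth flagging: `and not None` always evaluates to True (and cuda_version cannot be None here anyway, since the function returns earlier when it is falsy), so the conjunct never constrains anything; and `<=` on version strings is again lexicographic. A sketch of the presumably intended selection:

    # Sketch: compare numerically; the "not None" conjunct is a no-op.
    if tuple(int(p) for p in cuda_version.split(".")) <= (11, 8):
        device = "cu" + supported_cuda_versions[0].replace(".", "")   # "cu118"
    else:
        device = "cu" + supported_cuda_versions[-1].replace(".", "")  # "cu121"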
@@ -346,7 +385,7 @@ def llama_cpp_python_cuda_requires():
         cpu_device = "basic"
     device += cpu_device
     base_url = "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui"
-    llama_cpp_version = "0.2.10"
+    llama_cpp_version = "0.2.26"
     py_version = "cp310"
     os_pkg_name = "manylinux_2_31_x86_64" if os_type == OSType.LINUX else "win_amd64"
     extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
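Expanding the f-string with values visible in this hunk (cpu_device = "basic" appended to a "cu121" device on Linux; other cpu_device values come from logic outside this hunk), the resulting URL would look roughly like:

    # Illustrative expansion only; actual device/cpu_device depend on the host.
    device = "cu121" + "basic"  # -> "cu121basic"
    url = (
        "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/"
        "textgen-webui/llama_cpp_python_cuda-0.2.26+cu121basic-cp310-cp310-"
        "manylinux_2_31_x86_64.whl"
    )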
@@ -493,6 +532,12 @@ def quantization_requires():
     # autoawq requirements:
     # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
     # 2. CUDA Toolkit 11.8 and later.
+    cuda_version = get_cuda_version()
+    autoawq_latest_version = get_latest_version("autoawq", "", "0.2.4")
+    if cuda_version is None or cuda_version == "12.1":
+        quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
+    else:
+        # TODO(yyhhyy): Add autoawq install method for CUDA version 11.8
         quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
 
     setup_spec.extras["quantization"] = ["cpm_kernels"] + quantization_pkgs
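As committed, both branches of the new conditional add the same packages, and autoawq_latest_version is computed but never used; the TODO marks the CUDA 11.8 path as unfinished. A hypothetical way the fetched version could pin that path (not in the commit):

    # Hypothetical sketch: pin autoawq to the version discovered above.
    quantization_pkgs.extend(
        [f"autoawq=={autoawq_latest_version}", _build_autoawq_requires(), "optimum"]
    )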