feat: Upgrade torch version to 2.2.1 (#1374)

Co-authored-by: yyhhyy <yyhhyyyyyy@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Author: yyhhyy, committed by GitHub on 2024-04-08 18:35:19 +08:00
Commit: 5a0c20a864 (parent: bb77e13dee)
3 changed files with 103 additions and 58 deletions


@@ -1,4 +1,4 @@
-ARG BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
+ARG BASE_IMAGE="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
 FROM ${BASE_IMAGE}
 ARG BASE_IMAGE


@@ -4,7 +4,7 @@ SCRIPT_LOCATION=$0
 cd "$(dirname "$SCRIPT_LOCATION")"
 WORK_DIR=$(pwd)
-BASE_IMAGE_DEFAULT="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
+BASE_IMAGE_DEFAULT="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
 BASE_IMAGE_DEFAULT_CPU="ubuntu:22.04"
 BASE_IMAGE=$BASE_IMAGE_DEFAULT

@@ -21,7 +21,7 @@ BUILD_NETWORK=""
 DB_GPT_INSTALL_MODEL="default"

 usage () {
-    echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-runtime-ubuntu22.04] [--image-name db-gpt]"
+    echo "USAGE: $0 [--base-image nvidia/cuda:12.1.0-runtime-ubuntu22.04] [--image-name db-gpt]"
     echo "  [-b|--base-image base image name] Base image name"
     echo "  [-n|--image-name image name] Current image name, default: db-gpt"
     echo "  [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"

setup.py

@@ -4,6 +4,7 @@ import platform
 import re
 import shutil
 import subprocess
+import sys
 import urllib.request
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -40,15 +41,22 @@ def parse_requirements(file_name: str) -> List[str]:
 ]


+def find_python():
+    # sys.executable is the absolute path of the running interpreter
+    python_path = sys.executable
+    print(python_path)
+    if not python_path:
+        print("Python command not found.")
+        return None
+    return python_path
+
+
 def get_latest_version(package_name: str, index_url: str, default_version: str):
-    python_command = shutil.which("python")
-    if not python_command:
-        python_command = shutil.which("python3")
+    python_command = find_python()
     if not python_command:
         print("Python command not found.")
         return default_version

-    command = [
+    command_index_versions = [
         python_command,
         "-m",
         "pip",
@@ -59,20 +67,40 @@ def get_latest_version(package_name: str, index_url: str, default_version: str):
         index_url,
     ]

-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    if result.returncode != 0:
-        print("Error executing command.")
-        print(result.stderr.decode())
-        return default_version
-    output = result.stdout.decode()
-    lines = output.split("\n")
-    for line in lines:
-        if "Available versions:" in line:
-            available_versions = line.split(":")[1].strip()
-            latest_version = available_versions.split(",")[0].strip()
-            return latest_version
+    result_index_versions = subprocess.run(
+        command_index_versions, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
+    if result_index_versions.returncode == 0:
+        output = result_index_versions.stdout.decode()
+        lines = output.split("\n")
+        for line in lines:
+            if "Available versions:" in line:
+                available_versions = line.split(":")[1].strip()
+                latest_version = available_versions.split(",")[0].strip()
+                # Strip any local build tag (e.g. "2.2.1+cu121" -> "2.2.1")
+                # so torch/torchvision versions compare cleanly
+                if package_name in ("torch", "torchvision"):
+                    latest_version = latest_version.split("+")[0]
+                return latest_version
+    else:
+        # Fallback: "pip install <pkg>==" always fails, but the error
+        # message lists every version the index offers
+        command_simulate_install = [
+            python_command,
+            "-m",
+            "pip",
+            "install",
+            f"{package_name}==",
+        ]
+        result_simulate_install = subprocess.run(
+            command_simulate_install, stderr=subprocess.PIPE
+        )
+        stderr_output = result_simulate_install.stderr.decode()
+        print(stderr_output)
+        match = re.search(r"from versions: (.+?)\)", stderr_output)
+        if match:
+            available_versions = match.group(1).split(", ")
+            latest_version = available_versions[-1].strip()
+            return latest_version

     return default_version
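
Review note: the new fallback leans on a pip quirk: asking for the empty pin "pkg==" always fails, but the error message lists every version the index offers. A minimal standalone sketch of that trick (assumes network access to the index; not part of the commit):

    import re
    import subprocess
    import sys

    # "torch==" matches no version, so pip prints the full list to stderr
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch=="],
        stderr=subprocess.PIPE,
    )
    match = re.search(r"from versions: (.+?)\)", result.stderr.decode())
    if match:
        # pip lists versions oldest-first, so the last entry is the newest
        print(match.group(1).split(", ")[-1])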
@@ -227,7 +255,7 @@ def _build_wheels(
     base_url: str = None,
     base_url_func: Callable[[str, str, str], str] = None,
     pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
-    supported_cuda_versions: List[str] = ["11.7", "11.8"],
+    supported_cuda_versions: List[str] = ["11.8", "12.1"],
 ) -> Optional[str]:
     """
     Build the URL for the package wheel file based on the package name, version, and CUDA version.
@@ -248,10 +276,16 @@ def _build_wheels(
     py_version = "cp" + "".join(py_version.split(".")[0:2])
     if os_type == OSType.DARWIN or not cuda_version:
         return None
     if cuda_version not in supported_cuda_versions:
         print(
-            f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
+            f"Warning: CUDA version {cuda_version} is not in the supported list "
+            f"{supported_cuda_versions}; falling back to the nearest supported version."
         )
+        # Snap older toolkits down to 11.8 and newer ones to 12.1
+        if cuda_version < "12.1":
+            cuda_version = supported_cuda_versions[0]
+        else:
+            cuda_version = supported_cuda_versions[-1]
     cuda_version = "cu" + cuda_version.replace(".", "")
@@ -273,48 +307,49 @@ def _build_wheels(


 def torch_requires(
-    torch_version: str = "2.0.1",
-    torchvision_version: str = "0.15.2",
-    torchaudio_version: str = "2.0.2",
+    torch_version: str = "2.2.1",
+    torchvision_version: str = "0.17.1",
+    torchaudio_version: str = "2.2.1",
 ):
+    os_type, _ = get_cpu_avx_support()
     torch_pkgs = [
         f"torch=={torch_version}",
         f"torchvision=={torchvision_version}",
         f"torchaudio=={torchaudio_version}",
     ]
-    torch_cuda_pkgs = []
-    os_type, _ = get_cpu_avx_support()
+    # Start from the plain pins; entries are replaced with direct wheel URLs
+    # below only when CUDA-specific wheels can be resolved
+    torch_cuda_pkgs = torch_pkgs[:]
     if os_type != OSType.DARWIN:
-        cuda_version = get_cuda_version()
-        if cuda_version:
-            supported_versions = ["11.7", "11.8"]
-            # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            torch_url = _build_wheels(
-                "torch",
-                torch_version,
-                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
-                supported_cuda_versions=supported_versions,
-            )
-            torchvision_url = _build_wheels(
-                "torchvision",
-                torchvision_version,
-                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
-                supported_cuda_versions=supported_versions,
-            )
-            torch_url_cached = cache_package(
-                torch_url, "torch", os_type == OSType.WINDOWS
-            )
-            torchvision_url_cached = cache_package(
-                torchvision_url, "torchvision", os_type == OSType.WINDOWS
-            )
-            torch_cuda_pkgs = [
-                f"torch @ {torch_url_cached}",
-                f"torchvision @ {torchvision_url_cached}",
-                f"torchaudio=={torchaudio_version}",
-            ]
+        supported_versions = ["11.8", "12.1"]
+        base_url_func = lambda v, x, y: f"https://download.pytorch.org/whl/{x}"
+        torch_url = _build_wheels(
+            "torch",
+            torch_version,
+            base_url_func=base_url_func,
+            supported_cuda_versions=supported_versions,
+        )
+        torchvision_url = _build_wheels(
+            "torchvision",
+            torchvision_version,
+            base_url_func=base_url_func,
+            supported_cuda_versions=supported_versions,
+        )
+        # Cache the wheels locally and pin direct URLs when available
+        if torch_url:
+            torch_url_cached = cache_package(
+                torch_url, "torch", os_type == OSType.WINDOWS
+            )
+            torch_cuda_pkgs[0] = f"torch @ {torch_url_cached}"
+        if torchvision_url:
+            torchvision_url_cached = cache_package(
+                torchvision_url, "torchvision", os_type == OSType.WINDOWS
+            )
+            torch_cuda_pkgs[1] = f"torchvision @ {torchvision_url_cached}"

     setup_spec.extras["torch"] = torch_pkgs
     setup_spec.extras["torch_cpu"] = torch_pkgs
     setup_spec.extras["torch_cuda"] = torch_cuda_pkgs
@@ -322,6 +357,7 @@ def torch_requires(

 def llama_cpp_python_cuda_requires():
     cuda_version = get_cuda_version()
+    supported_cuda_versions = ["11.8", "12.1"]
     device = "cpu"
     if not cuda_version:
         print("CUDA not support, use cpu version")
@@ -330,7 +366,10 @@ def llama_cpp_python_cuda_requires():
         print("Disable GPU acceleration")
         return
     # Supports GPU acceleration
-    device = "cu" + cuda_version.replace(".", "")
+    # Map the detected toolkit onto the nearest wheel tag: 11.8 and older
+    # use cu118, everything newer uses cu121
+    if cuda_version <= "11.8":
+        device = "cu" + supported_cuda_versions[0].replace(".", "")
+    else:
+        device = "cu" + supported_cuda_versions[-1].replace(".", "")
     os_type, cpu_avx = get_cpu_avx_support()
     print(f"OS: {os_type}, cpu avx: {cpu_avx}")
     supported_os = [OSType.WINDOWS, OSType.LINUX]
@@ -346,7 +385,7 @@ def llama_cpp_python_cuda_requires():
         cpu_device = "basic"
     device += cpu_device
     base_url = "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui"
-    llama_cpp_version = "0.2.10"
+    llama_cpp_version = "0.2.26"
     py_version = "cp310"
     os_pkg_name = "manylinux_2_31_x86_64" if os_type == OSType.LINUX else "win_amd64"
     extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
@@ -493,6 +532,12 @@ def quantization_requires():
         # autoawq requirements:
         # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
         # 2. CUDA Toolkit 11.8 and later.
-        quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
+        cuda_version = get_cuda_version()
+        autoawq_latest_version = get_latest_version("autoawq", "", "0.2.4")
+        if cuda_version is None or cuda_version == "12.1":
+            quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
+        else:
+            # TODO(yyhhyy): Add autoawq install method for CUDA version 11.8
+            quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])

     setup_spec.extras["quantization"] = ["cpm_kernels"] + quantization_pkgs