Mirror of https://github.com/csunny/DB-GPT.git
feat(core): MTB supports multi-user and multi-system fields (#854)
setup.py (147 changed lines)
@@ -1,5 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Callable
 import setuptools
 import platform
 import subprocess
@@ -10,6 +9,7 @@ from urllib.parse import urlparse, quote
 import re
 import shutil
 from setuptools import find_packages
+import functools
 
 with open("README.md", mode="r", encoding="utf-8") as fh:
     long_description = fh.read()
@@ -34,8 +34,15 @@ def parse_requirements(file_name: str) -> List[str]:
 
 
 def get_latest_version(package_name: str, index_url: str, default_version: str):
+    python_command = shutil.which("python")
+    if not python_command:
+        python_command = shutil.which("python3")
+        if not python_command:
+            print("Python command not found.")
+            return default_version
+
     command = [
-        "python",
+        python_command,
         "-m",
         "pip",
        "index",
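The hunk ends mid-function, so how the command's output is consumed is not shown here. A minimal sketch of the presumed continuation, assuming `pip index versions <pkg> --index-url <url>` (still an experimental pip subcommand) prints a line of the form `Available versions: 0.5.1, 0.5.0, ...` with the newest version first:

```python
# Sketch only -- a plausible continuation of get_latest_version(), not part of this diff.
result = subprocess.run(command, capture_output=True, text=True)
if result.returncode == 0:
    for line in result.stdout.splitlines():
        if line.startswith("Available versions:"):
            # The newest version is listed first.
            return line.split(":", 1)[1].split(",")[0].strip()
return default_version
```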
@@ -125,6 +132,7 @@ class OSType(Enum):
     OTHER = "other"
 
 
+@functools.cache
 def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
     system = platform.system()
     os_type = OSType.OTHER
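`functools.cache` (available since Python 3.9) memoizes the result, so the platform/AVX probing runs at most once even though several of the `*_requires()` helpers below call `get_cpu_avx_support()`. A self-contained illustration of that behavior:

```python
import functools
import platform

@functools.cache
def probe() -> str:
    print("probing...")  # printed only on the first call
    return platform.system()

probe()  # probing... -> e.g. "Linux"
probe()  # served from the cache; no side effect
```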
@@ -206,6 +214,57 @@ def get_cuda_version() -> str:
         return None
 
 
+def _build_wheels(
+    pkg_name: str,
+    pkg_version: str,
+    base_url: str = None,
+    base_url_func: Callable[[str, str, str], str] = None,
+    pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
+    supported_cuda_versions: List[str] = ["11.7", "11.8"],
+) -> Optional[str]:
+    """
+    Build the URL of the package wheel file from the package name, version, and CUDA version.
+    Args:
+        pkg_name (str): The name of the package.
+        pkg_version (str): The version of the package.
+        base_url (str): The base URL for downloading the package.
+        base_url_func (Callable): A function that generates the base URL.
+        pkg_file_func (Callable): A function that builds the package file name.
+            Function params: pkg_name, pkg_version, cuda_version, py_version, OSType.
+        supported_cuda_versions (List[str]): The list of supported CUDA versions.
+    Returns:
+        Optional[str]: The URL of the package wheel file.
+    """
+    os_type, _ = get_cpu_avx_support()
+    cuda_version = get_cuda_version()
+    py_version = platform.python_version()
+    py_version = "cp" + "".join(py_version.split(".")[0:2])
+    if os_type == OSType.DARWIN or not cuda_version:
+        return None
+    if cuda_version not in supported_cuda_versions:
+        print(
+            f"Warning: {pkg_name} supports CUDA versions {supported_cuda_versions}, falling back to {supported_cuda_versions[-1]}"
+        )
+        cuda_version = supported_cuda_versions[-1]
+
+    cuda_version = "cu" + cuda_version.replace(".", "")
+    os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
+    if base_url_func:
+        base_url = base_url_func(pkg_version, cuda_version, py_version)
+        if base_url and base_url.endswith("/"):
+            base_url = base_url[:-1]
+    if pkg_file_func:
+        full_pkg_file = pkg_file_func(
+            pkg_name, pkg_version, cuda_version, py_version, os_type
+        )
+    else:
+        full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+    if not base_url:
+        return full_pkg_file
+    else:
+        return f"{base_url}/{full_pkg_file}"
+
+
 def torch_requires(
     torch_version: str = "2.0.1",
     torchvision_version: str = "0.15.2",
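The default wheel file name follows the PEP 427 pattern `{name}-{version}+{local}-{python tag}-{abi tag}-{platform}.whl`; the tag segment appears twice because the Python tag and the ABI tag are both, e.g., `cp310`. A worked example (environment-dependent, traced for Linux with CUDA 11.8 and CPython 3.10):

```python
# With os_type == OSType.LINUX, CUDA 11.8 and CPython 3.10 detected, this returns:
# https://download.pytorch.org/whl/cu118/torch-2.0.1+cu118-cp310-cp310-linux_x86_64.whl
url = _build_wheels(
    "torch",
    "2.0.1",
    base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
)
```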
@@ -222,16 +281,20 @@ def torch_requires(
     cuda_version = get_cuda_version()
     if cuda_version:
         supported_versions = ["11.7", "11.8"]
-        if cuda_version not in supported_versions:
-            print(
-                f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}"
-            )
-            cuda_version = supported_versions[-1]
-        cuda_version = "cu" + cuda_version.replace(".", "")
-        py_version = "cp310"
-        os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
-        torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-        torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+        # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+        # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+        torch_url = _build_wheels(
+            "torch",
+            torch_version,
+            base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
+            supported_cuda_versions=supported_versions,
+        )
+        torchvision_url = _build_wheels(
+            "torchvision",
+            torchvision_version,
+            base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
+            supported_cuda_versions=supported_versions,
+        )
         torch_url_cached = cache_package(
             torch_url, "torch", os_type == OSType.WINDOWS
         )
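Note the lambda's positional arguments: `_build_wheels` invokes `base_url_func(pkg_version, cuda_version, py_version)`, so `v`, `x`, and `y` bind to the package version, the normalized CUDA tag, and the Python tag, respectively. Evaluated by hand for the values above:

```python
# The same lambda as in the diff, applied to torch 2.0.1 / CUDA 11.8 / CPython 3.10.
f = lambda v, x, y: f"https://download.pytorch.org/whl/{x}"
print(f("2.0.1", "cu118", "cp310"))  # https://download.pytorch.org/whl/cu118
```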
@@ -327,6 +390,7 @@ def core_requires():
         "xlrd==2.0.1",
         # For caching. TODO: pympler has not been updated for a long time; we need to find a replacement toolkit.
         "pympler",
+        "aiofiles",
     ]
     if BUILD_FROM_SOURCE:
         setup_spec.extras["framework"].append(
@@ -360,6 +424,41 @@ def llama_cpp_requires():
         llama_cpp_python_cuda_requires()
 
 
+def _build_autoawq_requires() -> Optional[str]:
+    os_type, _ = get_cpu_avx_support()
+    if os_type == OSType.DARWIN:
+        return None
+    auto_gptq_version = get_latest_version(
+        "auto-gptq", "https://huggingface.github.io/autogptq-index/whl/cu118/", "0.5.1"
+    )
+    # e.g. 0.5.1+cu118
+    auto_gptq_version = auto_gptq_version.split("+")[0]
+
+    def pkg_file_func(pkg_name, pkg_version, cuda_version, py_version, os_type):
+        pkg_name = pkg_name.replace("-", "_")
+        if os_type == OSType.DARWIN:
+            return None
+        os_pkg_name = (
+            "manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+            if os_type == OSType.LINUX
+            else "win_amd64.whl"
+        )
+        return f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}"
+
+    auto_gptq_url = _build_wheels(
+        "auto-gptq",
+        auto_gptq_version,
+        base_url_func=lambda v, x, y: f"https://huggingface.github.io/autogptq-index/whl/{x}/auto-gptq",
+        pkg_file_func=pkg_file_func,
+        supported_cuda_versions=["11.8"],
+    )
+    if auto_gptq_url:
+        print(f"Install auto-gptq from {auto_gptq_url}")
+        return f"auto-gptq @ {auto_gptq_url}"
+    else:
+        return "auto-gptq"
+
+
 def quantization_requires():
     pkgs = []
     os_type, _ = get_cpu_avx_support()
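On a supported host the helper yields a PEP 508 direct reference. Tracing the code for Linux with CUDA 11.8 and CPython 3.10, and assuming the version lookup falls back to the default `0.5.1` (the actual value depends on the index at install time):

```python
# Traced by hand: pkg_file_func replaces "-" with "_" and picks a manylinux platform tag.
requirement = _build_autoawq_requires()
# 'auto-gptq @ https://huggingface.github.io/autogptq-index/whl/cu118/auto-gptq/auto_gptq-0.5.1+cu118-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl'
```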
@@ -379,6 +478,28 @@ def quantization_requires():
         print(pkgs)
     # For chatglm2-6b-int4
     pkgs += ["cpm_kernels"]
 
+    # Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM.
+    # autoawq requirements:
+    # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
+    # 2. CUDA Toolkit 11.8 and later.
+    autoawq_url = _build_wheels(
+        "autoawq",
+        "0.1.7",
+        base_url_func=lambda v, x, y: f"https://github.com/casper-hansen/AutoAWQ/releases/download/v{v}",
+        supported_cuda_versions=["11.8"],
+    )
+    if autoawq_url:
+        print(f"Install autoawq from {autoawq_url}")
+        pkgs.append(f"autoawq @ {autoawq_url}")
+    else:
+        pkgs.append("autoawq")
+
+    auto_gptq_pkg = _build_autoawq_requires()
+    if auto_gptq_pkg:
+        pkgs.append(auto_gptq_pkg)
+        pkgs.append("optimum")
+
     setup_spec.extras["quantization"] = pkgs
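Putting it together: on a Linux host with CUDA 11.8 and CPython 3.10, the quantization extra would end up with entries roughly like the list below (illustrative only; the earlier, unshown part of the function contributes additional platform-specific packages). Because the pins are PEP 508 direct references, they are resolved when the extra is installed, e.g. via an editable install such as `pip install -e ".[quantization]"` from a source checkout (hypothetical command, not from this diff):

```python
# Illustrative final contents of setup_spec.extras["quantization"]:
[
    # ...entries appended earlier in the function...
    "cpm_kernels",
    "autoawq @ https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp310-cp310-linux_x86_64.whl",
    "auto-gptq @ https://huggingface.github.io/autogptq-index/whl/cu118/auto-gptq/auto_gptq-0.5.1+cu118-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
    "optimum",
]
```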