feat(core): MTB supports multi-user and multi-system fields (#854)

This commit is contained in:
FangYin Cheng 2023-11-27 20:17:56 +08:00 committed by GitHub
parent 20aac6340b
commit eeff46487d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 262 additions and 83 deletions

View File

@ -66,6 +66,7 @@ CREATE TABLE `connect_config` (
`db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
`db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
`comment` text COMMENT 'db comment',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
PRIMARY KEY (`id`),
UNIQUE KEY `uk_db` (`db_name`),
KEY `idx_q_db_type` (`db_type`)
@ -78,6 +79,7 @@ CREATE TABLE `chat_history` (
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
@ -110,6 +112,7 @@ CREATE TABLE `my_plugin` (
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
`use_count` int DEFAULT NULL COMMENT 'plugin total use count',
`succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
PRIMARY KEY (`id`),
UNIQUE KEY `name` (`name`)
@ -141,6 +144,7 @@ CREATE TABLE `prompt_manage` (
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),

View File

@ -32,6 +32,7 @@ class MyPluginEntity(Base):
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
@ -58,6 +59,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
version=engity.version,
use_count=engity.use_count or 0,
succ_count=engity.succ_count or 0,
sys_code=engity.sys_code,
gmt_created=datetime.now(),
)
session.add(my_plugin)
@ -107,6 +109,8 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
@ -133,6 +137,8 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
if query.sys_code is not None:
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
count = my_plugins.scalar()
session.close()
return count

View File

@ -128,6 +128,10 @@ LLM_MODEL_CONFIG = {
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
# https://huggingface.co/01-ai/Yi-34B-Chat
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
# https://huggingface.co/01-ai/Yi-34B-Chat-8bits
"yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"),
# https://huggingface.co/01-ai/Yi-34B-Chat-4bits
"yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"),
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}

View File

@ -24,6 +24,7 @@ class ConnectConfigEntity(Base):
db_user = Column(String(255), nullable=True, comment="db user")
db_pwd = Column(String(255), nullable=True, comment="db password")
comment = Column(Text, nullable=True, comment="db comment")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
__table_args__ = (
UniqueConstraint("db_name", name="uk_db"),

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from typing import List, Optional, Dict
from enum import Enum
from pilot.scene.message import OnceConversation
@ -35,11 +35,6 @@ class BaseChatHistoryMemory(ABC):
# def clear(self) -> None:
# """Clear session memory from the local file"""
@abstractmethod
def conv_list(self, user_name: str = None) -> None:
"""get user's conversation list"""
pass
@abstractmethod
def update(self, messages: List[OnceConversation]) -> None:
pass
@ -49,7 +44,7 @@ class BaseChatHistoryMemory(ABC):
pass
@abstractmethod
def conv_info(self, conv_uid: str = None) -> None:
def conv_info(self, conv_uid: Optional[str] = None) -> None:
pass
@abstractmethod
@ -57,5 +52,7 @@ class BaseChatHistoryMemory(ABC):
pass
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
pass
def conv_list(
user_name: Optional[str] = None, sys_code: Optional[str] = None
) -> List[Dict]:
"""get user's conversation list"""

View File

@ -1,3 +1,4 @@
from typing import Type
from .base import MemoryStoreType
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
@ -32,5 +33,5 @@ class ChatHistory:
chat_session_id
)
def get_store_cls(self):
def get_store_cls(self) -> Type[BaseChatHistoryMemory]:
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
@ -32,7 +32,7 @@ class ChatHistoryEntity(Base):
messages = Column(
Text(length=2**31 - 1), nullable=True, comment="Conversation details"
)
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
UniqueConstraint("conv_uid", name="uk_conversation")
Index("idx_q_user", "user_name")
Index("idx_q_mode", "chat_mode")
@ -48,11 +48,15 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
session=session,
)
def list_last_20(self, user_name: str = None):
def list_last_20(
self, user_name: Optional[str] = None, sys_code: Optional[str] = None
):
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
if user_name:
chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
if sys_code:
chat_history = chat_history.filter(ChatHistoryEntity.sys_code == sys_code)
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())

View File

@ -1,7 +1,7 @@
import json
import os
import duckdb
from typing import List
from typing import List, Dict, Optional
from pilot.configs.config import Config
from pilot.memory.chat_history.base import BaseChatHistoryMemory
@ -37,7 +37,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
if not result:
# 如果表不存在,则创建新表
self.connect.execute(
"CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
"CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)"
)
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
@ -61,8 +61,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
try:
cursor = self.connect.cursor()
cursor.execute(
"INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
[self.chat_seesion_id, chat_mode, summary, user_name, ""],
"INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
[self.chat_seesion_id, chat_mode, summary, user_name, "", ""],
)
cursor.commit()
self.connect.commit()
@ -83,12 +83,13 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
)
else:
cursor.execute(
"INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
"INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
[
self.chat_seesion_id,
once_message.chat_mode,
once_message.get_user_conv().content,
"",
once_message.user_name,
once_message.sys_code,
json.dumps(conversations, ensure_ascii=False),
],
)
@ -149,17 +150,26 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return None
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
def conv_list(
user_name: Optional[str] = None, sys_code: Optional[str] = None
) -> List[Dict]:
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
query = "SELECT * FROM chat_history"
params = []
conditions = []
if user_name:
cursor.execute(
"SELECT * FROM chat_history where user_name=? order by id desc limit 20",
[user_name],
)
else:
cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
# 获取查询结果字段名
conditions.append("user_name = ?")
params.append(user_name)
if sys_code:
conditions.append("sys_code = ?")
params.append(sys_code)
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY id DESC LIMIT 20"
cursor.execute(query, params)
fields = [field[0] for field in cursor.description]
data = []
for row in cursor.fetchall():

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import List
from typing import List, Dict, Optional
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
from pilot.configs.config import Config
@ -62,7 +62,8 @@ class DbHistoryMemory(BaseChatHistoryMemory):
chat_history: ChatHistoryEntity = ChatHistoryEntity()
chat_history.conv_uid = self.chat_seesion_id
chat_history.chat_mode = once_message.chat_mode
chat_history.user_name = "default"
chat_history.user_name = once_message.user_name
chat_history.sys_code = once_message.sys_code
chat_history.summary = once_message.get_user_conv().content
conversations.append(_conversation_to_dic(once_message))
@ -92,9 +93,11 @@ class DbHistoryMemory(BaseChatHistoryMemory):
return []
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
def conv_list(
user_name: Optional[str] = None, sys_code: Optional[str] = None
) -> List[Dict]:
chat_history_dao = ChatHistoryDao()
history_list = chat_history_dao.list_last_20()
history_list = chat_history_dao.list_last_20(user_name, sys_code)
result = []
for history in history_list:
result.append(history.__dict__)

View File

@ -2,7 +2,7 @@ import json
import uuid
import asyncio
import os
import shutil
import aiofiles
import logging
from fastapi import (
APIRouter,
@ -17,7 +17,7 @@ from fastapi import (
from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError
from typing import List
from typing import List, Optional
import tempfile
from concurrent.futures import Executor
@ -48,7 +48,11 @@ from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
from pilot.model.base import FlatSupportedModel
from pilot.utils.tracer import root_tracer, SpanType
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.utils.executor_utils import (
ExecutorFactory,
blocking_func_to_async,
DefaultExecutorFactory,
)
router = APIRouter()
CFG = Config()
@ -68,9 +72,11 @@ def __get_conv_user_message(conversations: dict):
return ""
def __new_conversation(chat_mode, user_id) -> ConversationVo:
def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
unique_id = uuid.uuid1()
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
return ConversationVo(
conv_uid=str(unique_id), chat_mode=chat_mode, sys_code=sys_code
)
def get_db_list():
@ -141,7 +147,9 @@ def get_worker_manager() -> WorkerManager:
def get_executor() -> Executor:
"""Get the global default executor"""
return CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
ComponentType.EXECUTOR_DEFAULT,
ExecutorFactory,
or_register_component=DefaultExecutorFactory,
).create()
@ -166,7 +174,6 @@ async def db_connect_delete(db_name: str = None):
async def async_db_summary_embedding(db_name, db_type):
# 在这里执行需要异步运行的代码
db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
db_summary_client.db_summary_embedding(db_name, db_type)
@ -200,16 +207,21 @@ async def db_support_types():
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(user_id: str = None):
async def dialogue_list(
user_name: str = None, user_id: str = None, sys_code: str = None
):
dialogues: List = []
chat_history_service = ChatHistory()
# TODO Change the synchronous call to the asynchronous call
datas = chat_history_service.get_store_cls().conv_list(user_id)
user_name = user_name or user_id
datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code)
for item in datas:
conv_uid = item.get("conv_uid")
summary = item.get("summary")
chat_mode = item.get("chat_mode")
model_name = item.get("model_name", CFG.LLM_MODEL)
user_name = item.get("user_name")
sys_code = item.get("sys_code")
messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x["chat_order"])
@ -223,6 +235,8 @@ async def dialogue_list(user_id: str = None):
chat_mode=chat_mode,
model_name=model_name,
select_param=select_param,
user_name=user_name,
sys_code=sys_code,
)
dialogues.append(conv_vo)
@ -254,9 +268,14 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
chat_mode: str = ChatScene.ChatNormal.value(),
user_name: str = None,
# TODO remove user id
user_id: str = None,
sys_code: str = None,
):
conv_vo = __new_conversation(chat_mode, user_id)
user_name = user_name or user_id
conv_vo = __new_conversation(chat_mode, user_name, sys_code)
return Result.succ(conv_vo)
@ -280,40 +299,40 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
@router.post("/v1/chat/mode/params/file/load")
async def params_load(
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
conv_uid: str,
chat_mode: str,
model_name: str,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
doc_file: UploadFile = File(...),
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
## file save
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
# We can not move temp file in windows system when we open file in context of `with`
tmp_fd, tmp_path = tempfile.mkstemp(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
)
# TODO Use noblocking file save with aiofiles
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename),
)
## chat prepare
# Save the uploaded file
upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
os.makedirs(upload_dir, exist_ok=True)
upload_path = os.path.join(upload_dir, doc_file.filename)
async with aiofiles.open(upload_path, "wb") as f:
await f.write(await doc_file.read())
# Prepare the chat
dialogue = ConversationVo(
conv_uid=conv_uid,
chat_mode=chat_mode,
select_param=doc_file.filename,
model_name=model_name,
user_name=user_name,
sys_code=sys_code,
)
chat: BaseChat = await get_chat_instance(dialogue)
resp = await chat.prepare()
### refresh messages
# Refresh messages
return Result.succ(get_hist_messages(conv_uid))
except Exception as e:
logger.error("excel load error!", e)
return Result.failed(code="E000X", msg=f"File Load Error {e}")
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
@router.post("/v1/chat/dialogue/delete")
@ -354,7 +373,9 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value()
if not dialogue.conv_uid:
conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
conv_vo = __new_conversation(
dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
)
dialogue.conv_uid = conv_vo.conv_uid
if not ChatScene.is_valid_mode(dialogue.chat_mode):
@ -364,13 +385,12 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
chat_param = {
"chat_session_id": dialogue.conv_uid,
"user_name": dialogue.user_name,
"sys_code": dialogue.sys_code,
"current_user_input": dialogue.user_input,
"select_param": dialogue.select_param,
"model_name": dialogue.model_name,
}
# chat: BaseChat = CHAT_FACTORY.get_implementation(
# dialogue.chat_mode, **{"chat_param": chat_param}
# )
chat: BaseChat = await blocking_func_to_async(
get_executor(),
CHAT_FACTORY.get_implementation,
@ -401,8 +421,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
):
chat: BaseChat = await get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",

View File

@ -66,6 +66,8 @@ class ConversationVo(BaseModel):
"""
incremental: bool = False
sys_code: Optional[str] = None
class MessageVo(BaseModel):
"""

View File

@ -78,7 +78,9 @@ class BaseChat(ABC):
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(
self.chat_mode.value()
self.chat_mode.value(),
user_name=chat_param.get("user_name"),
sys_code=chat_param.get("sys_code"),
)
self.current_message.model_name = self.llm_model
if chat_param["select_param"]:
@ -171,7 +173,6 @@ class BaseChat(ABC):
"messages": llm_messages,
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
# "stop": self.prompt_template.sep,
"echo": self.llm_echo,
}
return payload

View File

@ -18,7 +18,7 @@ class OnceConversation:
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
"""
def __init__(self, chat_mode):
def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
self.chat_mode: str = chat_mode
self.messages: List[BaseMessage] = []
self.start_date: str = ""
@ -28,6 +28,8 @@ class OnceConversation:
self.param_value: str = ""
self.cost: int = 0
self.tokens: int = 0
self.user_name: str = user_name
self.sys_code: str = sys_code
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
@ -113,6 +115,8 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"messages": messages_to_dict(once.messages),
"param_type": once.param_type,
"param_value": once.param_value,
"user_name": once.user_name,
"sys_code": once.sys_code,
}
@ -121,7 +125,9 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation()
conversation = OnceConversation(
once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
)
conversation.cost = once.get("cost", 0)
conversation.chat_mode = once.get("chat_mode", "chat_normal")
conversation.tokens = once.get("tokens", 0)

View File

@ -29,6 +29,7 @@ class PromptManageEntity(Base):
prompt_name = Column(String(512))
content = Column(Text)
user_name = Column(String(128))
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)

147
setup.py
View File

@ -1,5 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional, Callable
import setuptools
import platform
import subprocess
@ -10,6 +9,7 @@ from urllib.parse import urlparse, quote
import re
import shutil
from setuptools import find_packages
import functools
with open("README.md", mode="r", encoding="utf-8") as fh:
long_description = fh.read()
@ -34,8 +34,15 @@ def parse_requirements(file_name: str) -> List[str]:
def get_latest_version(package_name: str, index_url: str, default_version: str):
python_command = shutil.which("python")
if not python_command:
python_command = shutil.which("python3")
if not python_command:
print("Python command not found.")
return default_version
command = [
"python",
python_command,
"-m",
"pip",
"index",
@ -125,6 +132,7 @@ class OSType(Enum):
OTHER = "other"
@functools.cache
def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
system = platform.system()
os_type = OSType.OTHER
@ -206,6 +214,57 @@ def get_cuda_version() -> str:
return None
def _build_wheels(
pkg_name: str,
pkg_version: str,
base_url: str = None,
base_url_func: Callable[[str, str, str], str] = None,
pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
supported_cuda_versions: List[str] = ["11.7", "11.8"],
) -> Optional[str]:
"""
Build the URL for the package wheel file based on the package name, version, and CUDA version.
Args:
pkg_name (str): The name of the package.
pkg_version (str): The version of the package.
base_url (str): The base URL for downloading the package.
base_url_func (Callable): A function to generate the base URL.
pkg_file_func (Callable): build package file function.
function params: pkg_name, pkg_version, cuda_version, py_version, OSType
supported_cuda_versions (List[str]): The list of supported CUDA versions.
Returns:
Optional[str]: The URL for the package wheel file.
"""
os_type, _ = get_cpu_avx_support()
cuda_version = get_cuda_version()
py_version = platform.python_version()
py_version = "cp" + "".join(py_version.split(".")[0:2])
if os_type == OSType.DARWIN or not cuda_version:
return None
if cuda_version not in supported_cuda_versions:
print(
f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
)
cuda_version = supported_cuda_versions[-1]
cuda_version = "cu" + cuda_version.replace(".", "")
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
if base_url_func:
base_url = base_url_func(pkg_version, cuda_version, py_version)
if base_url and base_url.endswith("/"):
base_url = base_url[:-1]
if pkg_file_func:
full_pkg_file = pkg_file_func(
pkg_name, pkg_version, cuda_version, py_version, os_type
)
else:
full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
if not base_url:
return full_pkg_file
else:
return f"{base_url}/{full_pkg_file}"
def torch_requires(
torch_version: str = "2.0.1",
torchvision_version: str = "0.15.2",
@ -222,16 +281,20 @@ def torch_requires(
cuda_version = get_cuda_version()
if cuda_version:
supported_versions = ["11.7", "11.8"]
if cuda_version not in supported_versions:
print(
f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}"
)
cuda_version = supported_versions[-1]
cuda_version = "cu" + cuda_version.replace(".", "")
py_version = "cp310"
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
# torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
# torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
torch_url = _build_wheels(
"torch",
torch_version,
base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
supported_cuda_versions=supported_versions,
)
torchvision_url = _build_wheels(
"torchvision",
torch_version,
base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
supported_cuda_versions=supported_versions,
)
torch_url_cached = cache_package(
torch_url, "torch", os_type == OSType.WINDOWS
)
@ -327,6 +390,7 @@ def core_requires():
"xlrd==2.0.1",
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
"pympler",
"aiofiles",
]
if BUILD_FROM_SOURCE:
setup_spec.extras["framework"].append(
@ -360,6 +424,41 @@ def llama_cpp_requires():
llama_cpp_python_cuda_requires()
def _build_autoawq_requires() -> Optional[str]:
os_type, _ = get_cpu_avx_support()
if os_type == OSType.DARWIN:
return None
auto_gptq_version = get_latest_version(
"auto-gptq", "https://huggingface.github.io/autogptq-index/whl/cu118/", "0.5.1"
)
# eg. 0.5.1+cu118
auto_gptq_version = auto_gptq_version.split("+")[0]
def pkg_file_func(pkg_name, pkg_version, cuda_version, py_version, os_type):
pkg_name = pkg_name.replace("-", "_")
if os_type == OSType.DARWIN:
return None
os_pkg_name = (
"manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
if os_type == OSType.LINUX
else "win_amd64.whl"
)
return f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}"
auto_gptq_url = _build_wheels(
"auto-gptq",
auto_gptq_version,
base_url_func=lambda v, x, y: f"https://huggingface.github.io/autogptq-index/whl/{x}/auto-gptq",
pkg_file_func=pkg_file_func,
supported_cuda_versions=["11.8"],
)
if auto_gptq_url:
print(f"Install auto-gptq from {auto_gptq_url}")
return f"auto-gptq @ {auto_gptq_url}"
else:
"auto-gptq"
def quantization_requires():
pkgs = []
os_type, _ = get_cpu_avx_support()
@ -379,6 +478,28 @@ def quantization_requires():
print(pkgs)
# For chatglm2-6b-int4
pkgs += ["cpm_kernels"]
# Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM.
# autoawq requirements:
# 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
# 2. CUDA Toolkit 11.8 and later.
autoawq_url = _build_wheels(
"autoawq",
"0.1.7",
base_url_func=lambda v, x, y: f"https://github.com/casper-hansen/AutoAWQ/releases/download/v{v}",
supported_cuda_versions=["11.8"],
)
if autoawq_url:
print(f"Install autoawq from {autoawq_url}")
pkgs.append(f"autoawq @ {autoawq_url}")
else:
pkgs.append("autoawq")
auto_gptq_pkg = _build_autoawq_requires()
if auto_gptq_pkg:
pkgs.append(auto_gptq_pkg)
pkgs.append("optimum")
setup_spec.extras["quantization"] = pkgs