mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-17 15:58:25 +00:00
feat(core): MTB supports multi-user and multi-system fields (#854)
This commit is contained in:
parent
20aac6340b
commit
eeff46487d
@ -66,6 +66,7 @@ CREATE TABLE `connect_config` (
|
|||||||
`db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
|
`db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
|
||||||
`db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
|
`db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
|
||||||
`comment` text COMMENT 'db comment',
|
`comment` text COMMENT 'db comment',
|
||||||
|
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||||
PRIMARY KEY (`id`),
|
PRIMARY KEY (`id`),
|
||||||
UNIQUE KEY `uk_db` (`db_name`),
|
UNIQUE KEY `uk_db` (`db_name`),
|
||||||
KEY `idx_q_db_type` (`db_type`)
|
KEY `idx_q_db_type` (`db_type`)
|
||||||
@ -78,6 +79,7 @@ CREATE TABLE `chat_history` (
|
|||||||
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
|
`summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
|
||||||
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
|
`user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
|
||||||
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
|
`messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
|
||||||
|
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||||
PRIMARY KEY (`id`)
|
PRIMARY KEY (`id`)
|
||||||
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
|
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
|
||||||
|
|
||||||
@ -110,6 +112,7 @@ CREATE TABLE `my_plugin` (
|
|||||||
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
|
`version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
|
||||||
`use_count` int DEFAULT NULL COMMENT 'plugin total use count',
|
`use_count` int DEFAULT NULL COMMENT 'plugin total use count',
|
||||||
`succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
|
`succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
|
||||||
|
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
|
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
|
||||||
PRIMARY KEY (`id`),
|
PRIMARY KEY (`id`),
|
||||||
UNIQUE KEY `name` (`name`)
|
UNIQUE KEY `name` (`name`)
|
||||||
@ -141,6 +144,7 @@ CREATE TABLE `prompt_manage` (
|
|||||||
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
|
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
|
||||||
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
|
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
|
||||||
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
|
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
|
||||||
|
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
|
||||||
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||||
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||||
PRIMARY KEY (`id`),
|
PRIMARY KEY (`id`),
|
||||||
|
@ -32,6 +32,7 @@ class MyPluginEntity(Base):
|
|||||||
succ_count = Column(
|
succ_count = Column(
|
||||||
Integer, nullable=True, default=0, comment="plugin total success count"
|
Integer, nullable=True, default=0, comment="plugin total success count"
|
||||||
)
|
)
|
||||||
|
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||||
gmt_created = Column(
|
gmt_created = Column(
|
||||||
DateTime, default=datetime.utcnow, comment="plugin install time"
|
DateTime, default=datetime.utcnow, comment="plugin install time"
|
||||||
)
|
)
|
||||||
@ -58,6 +59,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
|||||||
version=engity.version,
|
version=engity.version,
|
||||||
use_count=engity.use_count or 0,
|
use_count=engity.use_count or 0,
|
||||||
succ_count=engity.succ_count or 0,
|
succ_count=engity.succ_count or 0,
|
||||||
|
sys_code=engity.sys_code,
|
||||||
gmt_created=datetime.now(),
|
gmt_created=datetime.now(),
|
||||||
)
|
)
|
||||||
session.add(my_plugin)
|
session.add(my_plugin)
|
||||||
@ -107,6 +109,8 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
|||||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
||||||
if query.user_name is not None:
|
if query.user_name is not None:
|
||||||
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
||||||
|
if query.sys_code is not None:
|
||||||
|
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
|
||||||
|
|
||||||
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
|
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
|
||||||
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
|
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
|
||||||
@ -133,6 +137,8 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
|||||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
||||||
if query.user_name is not None:
|
if query.user_name is not None:
|
||||||
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
||||||
|
if query.sys_code is not None:
|
||||||
|
my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
|
||||||
count = my_plugins.scalar()
|
count = my_plugins.scalar()
|
||||||
session.close()
|
session.close()
|
||||||
return count
|
return count
|
||||||
|
@ -128,6 +128,10 @@ LLM_MODEL_CONFIG = {
|
|||||||
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
|
"xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
|
||||||
# https://huggingface.co/01-ai/Yi-34B-Chat
|
# https://huggingface.co/01-ai/Yi-34B-Chat
|
||||||
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
|
"yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
|
||||||
|
# https://huggingface.co/01-ai/Yi-34B-Chat-8bits
|
||||||
|
"yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"),
|
||||||
|
# https://huggingface.co/01-ai/Yi-34B-Chat-4bits
|
||||||
|
"yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"),
|
||||||
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
|
"yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class ConnectConfigEntity(Base):
|
|||||||
db_user = Column(String(255), nullable=True, comment="db user")
|
db_user = Column(String(255), nullable=True, comment="db user")
|
||||||
db_pwd = Column(String(255), nullable=True, comment="db password")
|
db_pwd = Column(String(255), nullable=True, comment="db password")
|
||||||
comment = Column(Text, nullable=True, comment="db comment")
|
comment = Column(Text, nullable=True, comment="db comment")
|
||||||
|
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint("db_name", name="uk_db"),
|
UniqueConstraint("db_name", name="uk_db"),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List
|
from typing import List, Optional, Dict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pilot.scene.message import OnceConversation
|
from pilot.scene.message import OnceConversation
|
||||||
|
|
||||||
@ -35,11 +35,6 @@ class BaseChatHistoryMemory(ABC):
|
|||||||
# def clear(self) -> None:
|
# def clear(self) -> None:
|
||||||
# """Clear session memory from the local file"""
|
# """Clear session memory from the local file"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def conv_list(self, user_name: str = None) -> None:
|
|
||||||
"""get user's conversation list"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, messages: List[OnceConversation]) -> None:
|
def update(self, messages: List[OnceConversation]) -> None:
|
||||||
pass
|
pass
|
||||||
@ -49,7 +44,7 @@ class BaseChatHistoryMemory(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def conv_info(self, conv_uid: str = None) -> None:
|
def conv_info(self, conv_uid: Optional[str] = None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -57,5 +52,7 @@ class BaseChatHistoryMemory(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def conv_list(cls, user_name: str = None) -> None:
|
def conv_list(
|
||||||
pass
|
user_name: Optional[str] = None, sys_code: Optional[str] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""get user's conversation list"""
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Type
|
||||||
from .base import MemoryStoreType
|
from .base import MemoryStoreType
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
@ -32,5 +33,5 @@ class ChatHistory:
|
|||||||
chat_session_id
|
chat_session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_store_cls(self):
|
def get_store_cls(self) -> Type[BaseChatHistoryMemory]:
|
||||||
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
|
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
|
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
|
||||||
from sqlalchemy import UniqueConstraint
|
from sqlalchemy import UniqueConstraint
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ class ChatHistoryEntity(Base):
|
|||||||
messages = Column(
|
messages = Column(
|
||||||
Text(length=2**31 - 1), nullable=True, comment="Conversation details"
|
Text(length=2**31 - 1), nullable=True, comment="Conversation details"
|
||||||
)
|
)
|
||||||
|
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||||
UniqueConstraint("conv_uid", name="uk_conversation")
|
UniqueConstraint("conv_uid", name="uk_conversation")
|
||||||
Index("idx_q_user", "user_name")
|
Index("idx_q_user", "user_name")
|
||||||
Index("idx_q_mode", "chat_mode")
|
Index("idx_q_mode", "chat_mode")
|
||||||
@ -48,11 +48,15 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
|
|||||||
session=session,
|
session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_last_20(self, user_name: str = None):
|
def list_last_20(
|
||||||
|
self, user_name: Optional[str] = None, sys_code: Optional[str] = None
|
||||||
|
):
|
||||||
session = self.get_session()
|
session = self.get_session()
|
||||||
chat_history = session.query(ChatHistoryEntity)
|
chat_history = session.query(ChatHistoryEntity)
|
||||||
if user_name:
|
if user_name:
|
||||||
chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
|
chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
|
||||||
|
if sys_code:
|
||||||
|
chat_history = chat_history.filter(ChatHistoryEntity.sys_code == sys_code)
|
||||||
|
|
||||||
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
|
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import duckdb
|
import duckdb
|
||||||
from typing import List
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||||
@ -37,7 +37,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
if not result:
|
if not result:
|
||||||
# 如果表不存在,则创建新表
|
# 如果表不存在,则创建新表
|
||||||
self.connect.execute(
|
self.connect.execute(
|
||||||
"CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
|
"CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)"
|
||||||
)
|
)
|
||||||
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
|
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
|
||||||
|
|
||||||
@ -61,8 +61,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
try:
|
try:
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
|
"INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
|
||||||
[self.chat_seesion_id, chat_mode, summary, user_name, ""],
|
[self.chat_seesion_id, chat_mode, summary, user_name, "", ""],
|
||||||
)
|
)
|
||||||
cursor.commit()
|
cursor.commit()
|
||||||
self.connect.commit()
|
self.connect.commit()
|
||||||
@ -83,12 +83,13 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
|
"INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
|
||||||
[
|
[
|
||||||
self.chat_seesion_id,
|
self.chat_seesion_id,
|
||||||
once_message.chat_mode,
|
once_message.chat_mode,
|
||||||
once_message.get_user_conv().content,
|
once_message.get_user_conv().content,
|
||||||
"",
|
once_message.user_name,
|
||||||
|
once_message.sys_code,
|
||||||
json.dumps(conversations, ensure_ascii=False),
|
json.dumps(conversations, ensure_ascii=False),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -149,17 +150,26 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def conv_list(cls, user_name: str = None) -> None:
|
def conv_list(
|
||||||
|
user_name: Optional[str] = None, sys_code: Optional[str] = None
|
||||||
|
) -> List[Dict]:
|
||||||
if os.path.isfile(duckdb_path):
|
if os.path.isfile(duckdb_path):
|
||||||
cursor = duckdb.connect(duckdb_path).cursor()
|
cursor = duckdb.connect(duckdb_path).cursor()
|
||||||
|
query = "SELECT * FROM chat_history"
|
||||||
|
params = []
|
||||||
|
conditions = []
|
||||||
if user_name:
|
if user_name:
|
||||||
cursor.execute(
|
conditions.append("user_name = ?")
|
||||||
"SELECT * FROM chat_history where user_name=? order by id desc limit 20",
|
params.append(user_name)
|
||||||
[user_name],
|
if sys_code:
|
||||||
)
|
conditions.append("sys_code = ?")
|
||||||
else:
|
params.append(sys_code)
|
||||||
cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
|
|
||||||
# 获取查询结果字段名
|
if conditions:
|
||||||
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
|
query += " ORDER BY id DESC LIMIT 20"
|
||||||
|
cursor.execute(query, params)
|
||||||
fields = [field[0] for field in cursor.description]
|
fields = [field[0] for field in cursor.description]
|
||||||
data = []
|
data = []
|
||||||
for row in cursor.fetchall():
|
for row in cursor.fetchall():
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Dict, Optional
|
||||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
|
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
|
||||||
from sqlalchemy import UniqueConstraint
|
from sqlalchemy import UniqueConstraint
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
@ -62,7 +62,8 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
chat_history: ChatHistoryEntity = ChatHistoryEntity()
|
chat_history: ChatHistoryEntity = ChatHistoryEntity()
|
||||||
chat_history.conv_uid = self.chat_seesion_id
|
chat_history.conv_uid = self.chat_seesion_id
|
||||||
chat_history.chat_mode = once_message.chat_mode
|
chat_history.chat_mode = once_message.chat_mode
|
||||||
chat_history.user_name = "default"
|
chat_history.user_name = once_message.user_name
|
||||||
|
chat_history.sys_code = once_message.sys_code
|
||||||
chat_history.summary = once_message.get_user_conv().content
|
chat_history.summary = once_message.get_user_conv().content
|
||||||
|
|
||||||
conversations.append(_conversation_to_dic(once_message))
|
conversations.append(_conversation_to_dic(once_message))
|
||||||
@ -92,9 +93,11 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def conv_list(cls, user_name: str = None) -> None:
|
def conv_list(
|
||||||
|
user_name: Optional[str] = None, sys_code: Optional[str] = None
|
||||||
|
) -> List[Dict]:
|
||||||
chat_history_dao = ChatHistoryDao()
|
chat_history_dao = ChatHistoryDao()
|
||||||
history_list = chat_history_dao.list_last_20()
|
history_list = chat_history_dao.list_last_20(user_name, sys_code)
|
||||||
result = []
|
result = []
|
||||||
for history in history_list:
|
for history in history_list:
|
||||||
result.append(history.__dict__)
|
result.append(history.__dict__)
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import shutil
|
import aiofiles
|
||||||
import logging
|
import logging
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
@ -17,7 +17,7 @@ from fastapi import (
|
|||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
import tempfile
|
import tempfile
|
||||||
from concurrent.futures import Executor
|
from concurrent.futures import Executor
|
||||||
|
|
||||||
@ -48,7 +48,11 @@ from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
|||||||
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
|
||||||
from pilot.model.base import FlatSupportedModel
|
from pilot.model.base import FlatSupportedModel
|
||||||
from pilot.utils.tracer import root_tracer, SpanType
|
from pilot.utils.tracer import root_tracer, SpanType
|
||||||
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
|
from pilot.utils.executor_utils import (
|
||||||
|
ExecutorFactory,
|
||||||
|
blocking_func_to_async,
|
||||||
|
DefaultExecutorFactory,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -68,9 +72,11 @@ def __get_conv_user_message(conversations: dict):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def __new_conversation(chat_mode, user_id) -> ConversationVo:
|
def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
|
||||||
unique_id = uuid.uuid1()
|
unique_id = uuid.uuid1()
|
||||||
return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
|
return ConversationVo(
|
||||||
|
conv_uid=str(unique_id), chat_mode=chat_mode, sys_code=sys_code
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_db_list():
|
def get_db_list():
|
||||||
@ -141,7 +147,9 @@ def get_worker_manager() -> WorkerManager:
|
|||||||
def get_executor() -> Executor:
|
def get_executor() -> Executor:
|
||||||
"""Get the global default executor"""
|
"""Get the global default executor"""
|
||||||
return CFG.SYSTEM_APP.get_component(
|
return CFG.SYSTEM_APP.get_component(
|
||||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
ComponentType.EXECUTOR_DEFAULT,
|
||||||
|
ExecutorFactory,
|
||||||
|
or_register_component=DefaultExecutorFactory,
|
||||||
).create()
|
).create()
|
||||||
|
|
||||||
|
|
||||||
@ -166,7 +174,6 @@ async def db_connect_delete(db_name: str = None):
|
|||||||
|
|
||||||
|
|
||||||
async def async_db_summary_embedding(db_name, db_type):
|
async def async_db_summary_embedding(db_name, db_type):
|
||||||
# 在这里执行需要异步运行的代码
|
|
||||||
db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||||
db_summary_client.db_summary_embedding(db_name, db_type)
|
db_summary_client.db_summary_embedding(db_name, db_type)
|
||||||
|
|
||||||
@ -200,16 +207,21 @@ async def db_support_types():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||||
async def dialogue_list(user_id: str = None):
|
async def dialogue_list(
|
||||||
|
user_name: str = None, user_id: str = None, sys_code: str = None
|
||||||
|
):
|
||||||
dialogues: List = []
|
dialogues: List = []
|
||||||
chat_history_service = ChatHistory()
|
chat_history_service = ChatHistory()
|
||||||
# TODO Change the synchronous call to the asynchronous call
|
# TODO Change the synchronous call to the asynchronous call
|
||||||
datas = chat_history_service.get_store_cls().conv_list(user_id)
|
user_name = user_name or user_id
|
||||||
|
datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code)
|
||||||
for item in datas:
|
for item in datas:
|
||||||
conv_uid = item.get("conv_uid")
|
conv_uid = item.get("conv_uid")
|
||||||
summary = item.get("summary")
|
summary = item.get("summary")
|
||||||
chat_mode = item.get("chat_mode")
|
chat_mode = item.get("chat_mode")
|
||||||
model_name = item.get("model_name", CFG.LLM_MODEL)
|
model_name = item.get("model_name", CFG.LLM_MODEL)
|
||||||
|
user_name = item.get("user_name")
|
||||||
|
sys_code = item.get("sys_code")
|
||||||
|
|
||||||
messages = json.loads(item.get("messages"))
|
messages = json.loads(item.get("messages"))
|
||||||
last_round = max(messages, key=lambda x: x["chat_order"])
|
last_round = max(messages, key=lambda x: x["chat_order"])
|
||||||
@ -223,6 +235,8 @@ async def dialogue_list(user_id: str = None):
|
|||||||
chat_mode=chat_mode,
|
chat_mode=chat_mode,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
select_param=select_param,
|
select_param=select_param,
|
||||||
|
user_name=user_name,
|
||||||
|
sys_code=sys_code,
|
||||||
)
|
)
|
||||||
dialogues.append(conv_vo)
|
dialogues.append(conv_vo)
|
||||||
|
|
||||||
@ -254,9 +268,14 @@ async def dialogue_scenes():
|
|||||||
|
|
||||||
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
||||||
async def dialogue_new(
|
async def dialogue_new(
|
||||||
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
|
chat_mode: str = ChatScene.ChatNormal.value(),
|
||||||
|
user_name: str = None,
|
||||||
|
# TODO remove user id
|
||||||
|
user_id: str = None,
|
||||||
|
sys_code: str = None,
|
||||||
):
|
):
|
||||||
conv_vo = __new_conversation(chat_mode, user_id)
|
user_name = user_name or user_id
|
||||||
|
conv_vo = __new_conversation(chat_mode, user_name, sys_code)
|
||||||
return Result.succ(conv_vo)
|
return Result.succ(conv_vo)
|
||||||
|
|
||||||
|
|
||||||
@ -280,40 +299,40 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
|
|||||||
|
|
||||||
@router.post("/v1/chat/mode/params/file/load")
|
@router.post("/v1/chat/mode/params/file/load")
|
||||||
async def params_load(
|
async def params_load(
|
||||||
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
|
conv_uid: str,
|
||||||
|
chat_mode: str,
|
||||||
|
model_name: str,
|
||||||
|
user_name: Optional[str] = None,
|
||||||
|
sys_code: Optional[str] = None,
|
||||||
|
doc_file: UploadFile = File(...),
|
||||||
):
|
):
|
||||||
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
|
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
|
||||||
try:
|
try:
|
||||||
if doc_file:
|
if doc_file:
|
||||||
## file save
|
# Save the uploaded file
|
||||||
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
|
upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
|
||||||
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
|
os.makedirs(upload_dir, exist_ok=True)
|
||||||
# We can not move temp file in windows system when we open file in context of `with`
|
upload_path = os.path.join(upload_dir, doc_file.filename)
|
||||||
tmp_fd, tmp_path = tempfile.mkstemp(
|
async with aiofiles.open(upload_path, "wb") as f:
|
||||||
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
|
await f.write(await doc_file.read())
|
||||||
)
|
|
||||||
# TODO Use noblocking file save with aiofiles
|
# Prepare the chat
|
||||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
|
||||||
tmp.write(await doc_file.read())
|
|
||||||
shutil.move(
|
|
||||||
tmp_path,
|
|
||||||
os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename),
|
|
||||||
)
|
|
||||||
## chat prepare
|
|
||||||
dialogue = ConversationVo(
|
dialogue = ConversationVo(
|
||||||
conv_uid=conv_uid,
|
conv_uid=conv_uid,
|
||||||
chat_mode=chat_mode,
|
chat_mode=chat_mode,
|
||||||
select_param=doc_file.filename,
|
select_param=doc_file.filename,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
user_name=user_name,
|
||||||
|
sys_code=sys_code,
|
||||||
)
|
)
|
||||||
chat: BaseChat = await get_chat_instance(dialogue)
|
chat: BaseChat = await get_chat_instance(dialogue)
|
||||||
resp = await chat.prepare()
|
resp = await chat.prepare()
|
||||||
|
|
||||||
### refresh messages
|
# Refresh messages
|
||||||
return Result.succ(get_hist_messages(conv_uid))
|
return Result.succ(get_hist_messages(conv_uid))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("excel load error!", e)
|
logger.error("excel load error!", e)
|
||||||
return Result.failed(code="E000X", msg=f"File Load Error {e}")
|
return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/dialogue/delete")
|
@router.post("/v1/chat/dialogue/delete")
|
||||||
@ -354,7 +373,9 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
|||||||
if not dialogue.chat_mode:
|
if not dialogue.chat_mode:
|
||||||
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
||||||
if not dialogue.conv_uid:
|
if not dialogue.conv_uid:
|
||||||
conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
|
conv_vo = __new_conversation(
|
||||||
|
dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
|
||||||
|
)
|
||||||
dialogue.conv_uid = conv_vo.conv_uid
|
dialogue.conv_uid = conv_vo.conv_uid
|
||||||
|
|
||||||
if not ChatScene.is_valid_mode(dialogue.chat_mode):
|
if not ChatScene.is_valid_mode(dialogue.chat_mode):
|
||||||
@ -364,13 +385,12 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
|||||||
|
|
||||||
chat_param = {
|
chat_param = {
|
||||||
"chat_session_id": dialogue.conv_uid,
|
"chat_session_id": dialogue.conv_uid,
|
||||||
|
"user_name": dialogue.user_name,
|
||||||
|
"sys_code": dialogue.sys_code,
|
||||||
"current_user_input": dialogue.user_input,
|
"current_user_input": dialogue.user_input,
|
||||||
"select_param": dialogue.select_param,
|
"select_param": dialogue.select_param,
|
||||||
"model_name": dialogue.model_name,
|
"model_name": dialogue.model_name,
|
||||||
}
|
}
|
||||||
# chat: BaseChat = CHAT_FACTORY.get_implementation(
|
|
||||||
# dialogue.chat_mode, **{"chat_param": chat_param}
|
|
||||||
# )
|
|
||||||
chat: BaseChat = await blocking_func_to_async(
|
chat: BaseChat = await blocking_func_to_async(
|
||||||
get_executor(),
|
get_executor(),
|
||||||
CHAT_FACTORY.get_implementation,
|
CHAT_FACTORY.get_implementation,
|
||||||
@ -401,8 +421,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
|
"get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
|
||||||
):
|
):
|
||||||
chat: BaseChat = await get_chat_instance(dialogue)
|
chat: BaseChat = await get_chat_instance(dialogue)
|
||||||
# background_tasks = BackgroundTasks()
|
|
||||||
# background_tasks.add_task(release_model_semaphore)
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "text/event-stream",
|
"Content-Type": "text/event-stream",
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
|
@ -66,6 +66,8 @@ class ConversationVo(BaseModel):
|
|||||||
"""
|
"""
|
||||||
incremental: bool = False
|
incremental: bool = False
|
||||||
|
|
||||||
|
sys_code: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class MessageVo(BaseModel):
|
class MessageVo(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -78,7 +78,9 @@ class BaseChat(ABC):
|
|||||||
|
|
||||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||||
self.current_message: OnceConversation = OnceConversation(
|
self.current_message: OnceConversation = OnceConversation(
|
||||||
self.chat_mode.value()
|
self.chat_mode.value(),
|
||||||
|
user_name=chat_param.get("user_name"),
|
||||||
|
sys_code=chat_param.get("sys_code"),
|
||||||
)
|
)
|
||||||
self.current_message.model_name = self.llm_model
|
self.current_message.model_name = self.llm_model
|
||||||
if chat_param["select_param"]:
|
if chat_param["select_param"]:
|
||||||
@ -171,7 +173,6 @@ class BaseChat(ABC):
|
|||||||
"messages": llm_messages,
|
"messages": llm_messages,
|
||||||
"temperature": float(self.prompt_template.temperature),
|
"temperature": float(self.prompt_template.temperature),
|
||||||
"max_new_tokens": int(self.prompt_template.max_new_tokens),
|
"max_new_tokens": int(self.prompt_template.max_new_tokens),
|
||||||
# "stop": self.prompt_template.sep,
|
|
||||||
"echo": self.llm_echo,
|
"echo": self.llm_echo,
|
||||||
}
|
}
|
||||||
return payload
|
return payload
|
||||||
|
@ -18,7 +18,7 @@ class OnceConversation:
|
|||||||
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chat_mode):
|
def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
|
||||||
self.chat_mode: str = chat_mode
|
self.chat_mode: str = chat_mode
|
||||||
self.messages: List[BaseMessage] = []
|
self.messages: List[BaseMessage] = []
|
||||||
self.start_date: str = ""
|
self.start_date: str = ""
|
||||||
@ -28,6 +28,8 @@ class OnceConversation:
|
|||||||
self.param_value: str = ""
|
self.param_value: str = ""
|
||||||
self.cost: int = 0
|
self.cost: int = 0
|
||||||
self.tokens: int = 0
|
self.tokens: int = 0
|
||||||
|
self.user_name: str = user_name
|
||||||
|
self.sys_code: str = sys_code
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_user_message(self, message: str) -> None:
|
||||||
"""Add a user message to the store"""
|
"""Add a user message to the store"""
|
||||||
@ -113,6 +115,8 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
|
|||||||
"messages": messages_to_dict(once.messages),
|
"messages": messages_to_dict(once.messages),
|
||||||
"param_type": once.param_type,
|
"param_type": once.param_type,
|
||||||
"param_value": once.param_value,
|
"param_value": once.param_value,
|
||||||
|
"user_name": once.user_name,
|
||||||
|
"sys_code": once.sys_code,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -121,7 +125,9 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
|
|||||||
|
|
||||||
|
|
||||||
def conversation_from_dict(once: dict) -> OnceConversation:
|
def conversation_from_dict(once: dict) -> OnceConversation:
|
||||||
conversation = OnceConversation()
|
conversation = OnceConversation(
|
||||||
|
once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
|
||||||
|
)
|
||||||
conversation.cost = once.get("cost", 0)
|
conversation.cost = once.get("cost", 0)
|
||||||
conversation.chat_mode = once.get("chat_mode", "chat_normal")
|
conversation.chat_mode = once.get("chat_mode", "chat_normal")
|
||||||
conversation.tokens = once.get("tokens", 0)
|
conversation.tokens = once.get("tokens", 0)
|
||||||
|
@ -29,6 +29,7 @@ class PromptManageEntity(Base):
|
|||||||
prompt_name = Column(String(512))
|
prompt_name = Column(String(512))
|
||||||
content = Column(Text)
|
content = Column(Text)
|
||||||
user_name = Column(String(128))
|
user_name = Column(String(128))
|
||||||
|
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||||
gmt_created = Column(DateTime)
|
gmt_created = Column(DateTime)
|
||||||
gmt_modified = Column(DateTime)
|
gmt_modified = Column(DateTime)
|
||||||
|
|
||||||
|
147
setup.py
147
setup.py
@ -1,5 +1,4 @@
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple, Optional, Callable
|
||||||
|
|
||||||
import setuptools
|
import setuptools
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
@ -10,6 +9,7 @@ from urllib.parse import urlparse, quote
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from setuptools import find_packages
|
from setuptools import find_packages
|
||||||
|
import functools
|
||||||
|
|
||||||
with open("README.md", mode="r", encoding="utf-8") as fh:
|
with open("README.md", mode="r", encoding="utf-8") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
@ -34,8 +34,15 @@ def parse_requirements(file_name: str) -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def get_latest_version(package_name: str, index_url: str, default_version: str):
|
def get_latest_version(package_name: str, index_url: str, default_version: str):
|
||||||
|
python_command = shutil.which("python")
|
||||||
|
if not python_command:
|
||||||
|
python_command = shutil.which("python3")
|
||||||
|
if not python_command:
|
||||||
|
print("Python command not found.")
|
||||||
|
return default_version
|
||||||
|
|
||||||
command = [
|
command = [
|
||||||
"python",
|
python_command,
|
||||||
"-m",
|
"-m",
|
||||||
"pip",
|
"pip",
|
||||||
"index",
|
"index",
|
||||||
@ -125,6 +132,7 @@ class OSType(Enum):
|
|||||||
OTHER = "other"
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
|
def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
os_type = OSType.OTHER
|
os_type = OSType.OTHER
|
||||||
@ -206,6 +214,57 @@ def get_cuda_version() -> str:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_wheels(
|
||||||
|
pkg_name: str,
|
||||||
|
pkg_version: str,
|
||||||
|
base_url: str = None,
|
||||||
|
base_url_func: Callable[[str, str, str], str] = None,
|
||||||
|
pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
|
||||||
|
supported_cuda_versions: List[str] = ["11.7", "11.8"],
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Build the URL for the package wheel file based on the package name, version, and CUDA version.
|
||||||
|
Args:
|
||||||
|
pkg_name (str): The name of the package.
|
||||||
|
pkg_version (str): The version of the package.
|
||||||
|
base_url (str): The base URL for downloading the package.
|
||||||
|
base_url_func (Callable): A function to generate the base URL.
|
||||||
|
pkg_file_func (Callable): build package file function.
|
||||||
|
function params: pkg_name, pkg_version, cuda_version, py_version, OSType
|
||||||
|
supported_cuda_versions (List[str]): The list of supported CUDA versions.
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The URL for the package wheel file.
|
||||||
|
"""
|
||||||
|
os_type, _ = get_cpu_avx_support()
|
||||||
|
cuda_version = get_cuda_version()
|
||||||
|
py_version = platform.python_version()
|
||||||
|
py_version = "cp" + "".join(py_version.split(".")[0:2])
|
||||||
|
if os_type == OSType.DARWIN or not cuda_version:
|
||||||
|
return None
|
||||||
|
if cuda_version not in supported_cuda_versions:
|
||||||
|
print(
|
||||||
|
f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
|
||||||
|
)
|
||||||
|
cuda_version = supported_cuda_versions[-1]
|
||||||
|
|
||||||
|
cuda_version = "cu" + cuda_version.replace(".", "")
|
||||||
|
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
|
||||||
|
if base_url_func:
|
||||||
|
base_url = base_url_func(pkg_version, cuda_version, py_version)
|
||||||
|
if base_url and base_url.endswith("/"):
|
||||||
|
base_url = base_url[:-1]
|
||||||
|
if pkg_file_func:
|
||||||
|
full_pkg_file = pkg_file_func(
|
||||||
|
pkg_name, pkg_version, cuda_version, py_version, os_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
||||||
|
if not base_url:
|
||||||
|
return full_pkg_file
|
||||||
|
else:
|
||||||
|
return f"{base_url}/{full_pkg_file}"
|
||||||
|
|
||||||
|
|
||||||
def torch_requires(
|
def torch_requires(
|
||||||
torch_version: str = "2.0.1",
|
torch_version: str = "2.0.1",
|
||||||
torchvision_version: str = "0.15.2",
|
torchvision_version: str = "0.15.2",
|
||||||
@ -222,16 +281,20 @@ def torch_requires(
|
|||||||
cuda_version = get_cuda_version()
|
cuda_version = get_cuda_version()
|
||||||
if cuda_version:
|
if cuda_version:
|
||||||
supported_versions = ["11.7", "11.8"]
|
supported_versions = ["11.7", "11.8"]
|
||||||
if cuda_version not in supported_versions:
|
# torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
||||||
print(
|
# torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
||||||
f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}"
|
torch_url = _build_wheels(
|
||||||
)
|
"torch",
|
||||||
cuda_version = supported_versions[-1]
|
torch_version,
|
||||||
cuda_version = "cu" + cuda_version.replace(".", "")
|
base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
|
||||||
py_version = "cp310"
|
supported_cuda_versions=supported_versions,
|
||||||
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
|
)
|
||||||
torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
torchvision_url = _build_wheels(
|
||||||
torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
|
"torchvision",
|
||||||
|
torch_version,
|
||||||
|
base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
|
||||||
|
supported_cuda_versions=supported_versions,
|
||||||
|
)
|
||||||
torch_url_cached = cache_package(
|
torch_url_cached = cache_package(
|
||||||
torch_url, "torch", os_type == OSType.WINDOWS
|
torch_url, "torch", os_type == OSType.WINDOWS
|
||||||
)
|
)
|
||||||
@ -327,6 +390,7 @@ def core_requires():
|
|||||||
"xlrd==2.0.1",
|
"xlrd==2.0.1",
|
||||||
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
|
# for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
|
||||||
"pympler",
|
"pympler",
|
||||||
|
"aiofiles",
|
||||||
]
|
]
|
||||||
if BUILD_FROM_SOURCE:
|
if BUILD_FROM_SOURCE:
|
||||||
setup_spec.extras["framework"].append(
|
setup_spec.extras["framework"].append(
|
||||||
@ -360,6 +424,41 @@ def llama_cpp_requires():
|
|||||||
llama_cpp_python_cuda_requires()
|
llama_cpp_python_cuda_requires()
|
||||||
|
|
||||||
|
|
||||||
|
def _build_autoawq_requires() -> Optional[str]:
|
||||||
|
os_type, _ = get_cpu_avx_support()
|
||||||
|
if os_type == OSType.DARWIN:
|
||||||
|
return None
|
||||||
|
auto_gptq_version = get_latest_version(
|
||||||
|
"auto-gptq", "https://huggingface.github.io/autogptq-index/whl/cu118/", "0.5.1"
|
||||||
|
)
|
||||||
|
# eg. 0.5.1+cu118
|
||||||
|
auto_gptq_version = auto_gptq_version.split("+")[0]
|
||||||
|
|
||||||
|
def pkg_file_func(pkg_name, pkg_version, cuda_version, py_version, os_type):
|
||||||
|
pkg_name = pkg_name.replace("-", "_")
|
||||||
|
if os_type == OSType.DARWIN:
|
||||||
|
return None
|
||||||
|
os_pkg_name = (
|
||||||
|
"manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
|
||||||
|
if os_type == OSType.LINUX
|
||||||
|
else "win_amd64.whl"
|
||||||
|
)
|
||||||
|
return f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}"
|
||||||
|
|
||||||
|
auto_gptq_url = _build_wheels(
|
||||||
|
"auto-gptq",
|
||||||
|
auto_gptq_version,
|
||||||
|
base_url_func=lambda v, x, y: f"https://huggingface.github.io/autogptq-index/whl/{x}/auto-gptq",
|
||||||
|
pkg_file_func=pkg_file_func,
|
||||||
|
supported_cuda_versions=["11.8"],
|
||||||
|
)
|
||||||
|
if auto_gptq_url:
|
||||||
|
print(f"Install auto-gptq from {auto_gptq_url}")
|
||||||
|
return f"auto-gptq @ {auto_gptq_url}"
|
||||||
|
else:
|
||||||
|
"auto-gptq"
|
||||||
|
|
||||||
|
|
||||||
def quantization_requires():
|
def quantization_requires():
|
||||||
pkgs = []
|
pkgs = []
|
||||||
os_type, _ = get_cpu_avx_support()
|
os_type, _ = get_cpu_avx_support()
|
||||||
@ -379,6 +478,28 @@ def quantization_requires():
|
|||||||
print(pkgs)
|
print(pkgs)
|
||||||
# For chatglm2-6b-int4
|
# For chatglm2-6b-int4
|
||||||
pkgs += ["cpm_kernels"]
|
pkgs += ["cpm_kernels"]
|
||||||
|
|
||||||
|
# Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM.
|
||||||
|
# autoawq requirements:
|
||||||
|
# 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
|
||||||
|
# 2. CUDA Toolkit 11.8 and later.
|
||||||
|
autoawq_url = _build_wheels(
|
||||||
|
"autoawq",
|
||||||
|
"0.1.7",
|
||||||
|
base_url_func=lambda v, x, y: f"https://github.com/casper-hansen/AutoAWQ/releases/download/v{v}",
|
||||||
|
supported_cuda_versions=["11.8"],
|
||||||
|
)
|
||||||
|
if autoawq_url:
|
||||||
|
print(f"Install autoawq from {autoawq_url}")
|
||||||
|
pkgs.append(f"autoawq @ {autoawq_url}")
|
||||||
|
else:
|
||||||
|
pkgs.append("autoawq")
|
||||||
|
|
||||||
|
auto_gptq_pkg = _build_autoawq_requires()
|
||||||
|
if auto_gptq_pkg:
|
||||||
|
pkgs.append(auto_gptq_pkg)
|
||||||
|
pkgs.append("optimum")
|
||||||
|
|
||||||
setup_spec.extras["quantization"] = pkgs
|
setup_spec.extras["quantization"] = pkgs
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user