feat(core): MTB supports multi-user and multi-system fields (#854)

Author: FangYin Cheng, 2023-11-27 20:17:56 +08:00 (committed by GitHub)
Parent: 20aac6340b
Commit: eeff46487d
15 changed files with 262 additions and 83 deletions


@@ -66,6 +66,7 @@ CREATE TABLE `connect_config` (
   `db_user` varchar(255) DEFAULT NULL COMMENT 'db user',
   `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password',
   `comment` text COMMENT 'db comment',
+  `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
   PRIMARY KEY (`id`),
   UNIQUE KEY `uk_db` (`db_name`),
   KEY `idx_q_db_type` (`db_type`)
@@ -78,6 +79,7 @@ CREATE TABLE `chat_history` (
   `summary` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary',
   `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor',
   `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details',
+  `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
   PRIMARY KEY (`id`)
 ) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history';
@@ -110,6 +112,7 @@ CREATE TABLE `my_plugin` (
   `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version',
   `use_count` int DEFAULT NULL COMMENT 'plugin total use count',
   `succ_count` int DEFAULT NULL COMMENT 'plugin total success count',
+  `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
   `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time',
   PRIMARY KEY (`id`),
   UNIQUE KEY `name` (`name`)
@@ -141,6 +144,7 @@ CREATE TABLE `prompt_manage` (
   `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
   `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
   `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
+  `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
   `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
   `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
   PRIMARY KEY (`id`),
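Note: the schema file above only covers fresh installs. For databases created before this change, the new column has to be added by hand; a minimal migration sketch (SQLAlchemy used only for illustration, the connection URL is a placeholder, not part of this commit):

from sqlalchemy import create_engine, text

# Placeholder URL; point this at the actual DB-GPT MySQL instance.
engine = create_engine("mysql+pymysql://user:password@localhost:3306/dbgpt")
with engine.begin() as conn:
    for table in ("connect_config", "chat_history", "my_plugin", "prompt_manage"):
        # Add the multi-system field introduced by this commit.
        conn.execute(text(
            f"ALTER TABLE `{table}` ADD COLUMN `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code'"
        ))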


@@ -32,6 +32,7 @@ class MyPluginEntity(Base):
     succ_count = Column(
         Integer, nullable=True, default=0, comment="plugin total success count"
     )
+    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
     gmt_created = Column(
         DateTime, default=datetime.utcnow, comment="plugin install time"
     )
@@ -58,6 +59,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
             version=engity.version,
             use_count=engity.use_count or 0,
             succ_count=engity.succ_count or 0,
+            sys_code=engity.sys_code,
             gmt_created=datetime.now(),
         )
         session.add(my_plugin)
@@ -107,6 +109,8 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
             my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
         if query.user_name is not None:
             my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
+        if query.sys_code is not None:
+            my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
         my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
         my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
@@ -133,6 +137,8 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
             my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
         if query.user_name is not None:
             my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
+        if query.sys_code is not None:
+            my_plugins = my_plugins.filter(MyPluginEntity.sys_code == query.sys_code)
         count = my_plugins.scalar()
         session.close()
         return count


@@ -128,6 +128,10 @@ LLM_MODEL_CONFIG = {
     "xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
     # https://huggingface.co/01-ai/Yi-34B-Chat
     "yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
+    # https://huggingface.co/01-ai/Yi-34B-Chat-8bits
+    "yi-34b-chat-8bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-8bits"),
+    # https://huggingface.co/01-ai/Yi-34B-Chat-4bits
+    "yi-34b-chat-4bits": os.path.join(MODEL_PATH, "Yi-34B-Chat-4bits"),
     "yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
 }


@@ -24,6 +24,7 @@ class ConnectConfigEntity(Base):
     db_user = Column(String(255), nullable=True, comment="db user")
     db_pwd = Column(String(255), nullable=True, comment="db password")
     comment = Column(Text, nullable=True, comment="db comment")
+    sys_code = Column(String(128), index=True, nullable=True, comment="System code")

     __table_args__ = (
         UniqueConstraint("db_name", name="uk_db"),


@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional, Dict
 from enum import Enum

 from pilot.scene.message import OnceConversation
@@ -35,11 +35,6 @@ class BaseChatHistoryMemory(ABC):
     # def clear(self) -> None:
     #     """Clear session memory from the local file"""

-    @abstractmethod
-    def conv_list(self, user_name: str = None) -> None:
-        """get user's conversation list"""
-        pass
-
     @abstractmethod
     def update(self, messages: List[OnceConversation]) -> None:
         pass
@@ -49,7 +44,7 @@ class BaseChatHistoryMemory(ABC):
         pass

     @abstractmethod
-    def conv_info(self, conv_uid: str = None) -> None:
+    def conv_info(self, conv_uid: Optional[str] = None) -> None:
         pass

     @abstractmethod
@@ -57,5 +52,7 @@ class BaseChatHistoryMemory(ABC):
         pass

     @staticmethod
-    def conv_list(cls, user_name: str = None) -> None:
-        pass
+    def conv_list(
+        user_name: Optional[str] = None, sys_code: Optional[str] = None
+    ) -> List[Dict]:
+        """get user's conversation list"""


@@ -1,3 +1,4 @@
+from typing import Type
 from .base import MemoryStoreType
 from pilot.configs.config import Config
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
@@ -32,5 +33,5 @@ class ChatHistory:
             chat_session_id
         )

-    def get_store_cls(self):
+    def get_store_cls(self) -> Type[BaseChatHistoryMemory]:
         return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)


@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
 from sqlalchemy import UniqueConstraint
@@ -32,7 +32,7 @@ class ChatHistoryEntity(Base):
     messages = Column(
         Text(length=2**31 - 1), nullable=True, comment="Conversation details"
     )
+    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
     UniqueConstraint("conv_uid", name="uk_conversation")
     Index("idx_q_user", "user_name")
     Index("idx_q_mode", "chat_mode")
@@ -48,11 +48,15 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
             session=session,
         )

-    def list_last_20(self, user_name: str = None):
+    def list_last_20(
+        self, user_name: Optional[str] = None, sys_code: Optional[str] = None
+    ):
         session = self.get_session()
         chat_history = session.query(ChatHistoryEntity)
         if user_name:
             chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
+        if sys_code:
+            chat_history = chat_history.filter(ChatHistoryEntity.sys_code == sys_code)
         chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())


@@ -1,7 +1,7 @@
 import json
 import os
 import duckdb
-from typing import List
+from typing import List, Dict, Optional

 from pilot.configs.config import Config
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
@@ -37,7 +37,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if not result:
             # Create the table if it does not exist yet
             self.connect.execute(
-                "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)"
+                "CREATE TABLE chat_history (id integer primary key, conv_uid VARCHAR(100) UNIQUE, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), sys_code VARCHAR(128), messages TEXT)"
             )
             self.connect.execute("CREATE SEQUENCE seq_id START 1;")
@@ -61,8 +61,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         try:
             cursor = self.connect.cursor()
             cursor.execute(
-                "INSERT INTO chat_history(id, conv_uid, chat_mode summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
-                [self.chat_seesion_id, chat_mode, summary, user_name, ""],
+                "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
+                [self.chat_seesion_id, chat_mode, summary, user_name, "", ""],
             )
             cursor.commit()
             self.connect.commit()
@@ -83,12 +83,13 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
                 )
             else:
                 cursor.execute(
-                    "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, messages)VALUES(nextval('seq_id'),?,?,?,?,?)",
+                    "INSERT INTO chat_history(id, conv_uid, chat_mode, summary, user_name, sys_code, messages)VALUES(nextval('seq_id'),?,?,?,?,?,?)",
                     [
                         self.chat_seesion_id,
                         once_message.chat_mode,
                         once_message.get_user_conv().content,
-                        "",
+                        once_message.user_name,
+                        once_message.sys_code,
                         json.dumps(conversations, ensure_ascii=False),
                     ],
                 )
@@ -149,17 +150,26 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         return None

     @staticmethod
-    def conv_list(cls, user_name: str = None) -> None:
+    def conv_list(
+        user_name: Optional[str] = None, sys_code: Optional[str] = None
+    ) -> List[Dict]:
         if os.path.isfile(duckdb_path):
             cursor = duckdb.connect(duckdb_path).cursor()
+            query = "SELECT * FROM chat_history"
+            params = []
+            conditions = []
             if user_name:
-                cursor.execute(
-                    "SELECT * FROM chat_history where user_name=? order by id desc limit 20",
-                    [user_name],
-                )
-            else:
-                cursor.execute("SELECT * FROM chat_history order by id desc limit 20")
-            # Field names of the query result
+                conditions.append("user_name = ?")
+                params.append(user_name)
+            if sys_code:
+                conditions.append("sys_code = ?")
+                params.append(sys_code)
+
+            if conditions:
+                query += " WHERE " + " AND ".join(conditions)
+            query += " ORDER BY id DESC LIMIT 20"
+            cursor.execute(query, params)
             fields = [field[0] for field in cursor.description]
             data = []
             for row in cursor.fetchall():
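Worked through by hand, the builder above produces a single parameterized statement when both filters are supplied. A standalone sketch of the same logic (no DuckDB connection needed; the sample values are placeholders):

def build_conv_list_query(user_name=None, sys_code=None):
    # Mirrors the query construction in conv_list above.
    query = "SELECT * FROM chat_history"
    params, conditions = [], []
    if user_name:
        conditions.append("user_name = ?")
        params.append(user_name)
    if sys_code:
        conditions.append("sys_code = ?")
        params.append(sys_code)
    if conditions:
        query += " WHERE " + " AND ".join(conditions)
    query += " ORDER BY id DESC LIMIT 20"
    return query, params

print(build_conv_list_query("alice", "crm"))
# ('SELECT * FROM chat_history WHERE user_name = ? AND sys_code = ? ORDER BY id DESC LIMIT 20', ['alice', 'crm'])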


@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import List
+from typing import List, Dict, Optional
 from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
 from sqlalchemy import UniqueConstraint
 from pilot.configs.config import Config
@@ -62,7 +62,8 @@ class DbHistoryMemory(BaseChatHistoryMemory):
         chat_history: ChatHistoryEntity = ChatHistoryEntity()
         chat_history.conv_uid = self.chat_seesion_id
         chat_history.chat_mode = once_message.chat_mode
-        chat_history.user_name = "default"
+        chat_history.user_name = once_message.user_name
+        chat_history.sys_code = once_message.sys_code
         chat_history.summary = once_message.get_user_conv().content

         conversations.append(_conversation_to_dic(once_message))
@@ -92,9 +93,11 @@ class DbHistoryMemory(BaseChatHistoryMemory):
             return []

     @staticmethod
-    def conv_list(cls, user_name: str = None) -> None:
+    def conv_list(
+        user_name: Optional[str] = None, sys_code: Optional[str] = None
+    ) -> List[Dict]:
         chat_history_dao = ChatHistoryDao()
-        history_list = chat_history_dao.list_last_20()
+        history_list = chat_history_dao.list_last_20(user_name, sys_code)
         result = []
         for history in history_list:
             result.append(history.__dict__)


@@ -2,7 +2,7 @@ import json
 import uuid
 import asyncio
 import os
-import shutil
+import aiofiles
 import logging
 from fastapi import (
     APIRouter,
@@ -17,7 +17,7 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.exceptions import RequestValidationError
-from typing import List
+from typing import List, Optional
 import tempfile
 from concurrent.futures import Executor
@@ -48,7 +48,11 @@ from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
 from pilot.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory
 from pilot.model.base import FlatSupportedModel
 from pilot.utils.tracer import root_tracer, SpanType
-from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+from pilot.utils.executor_utils import (
+    ExecutorFactory,
+    blocking_func_to_async,
+    DefaultExecutorFactory,
+)

 router = APIRouter()
 CFG = Config()
@@ -68,9 +72,11 @@ def __get_conv_user_message(conversations: dict):
     return ""


-def __new_conversation(chat_mode, user_id) -> ConversationVo:
+def __new_conversation(chat_mode, user_name: str, sys_code: str) -> ConversationVo:
     unique_id = uuid.uuid1()
-    return ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)
+    return ConversationVo(
+        conv_uid=str(unique_id), chat_mode=chat_mode, sys_code=sys_code
+    )


 def get_db_list():
@@ -141,7 +147,9 @@ def get_worker_manager() -> WorkerManager:
 def get_executor() -> Executor:
     """Get the global default executor"""
     return CFG.SYSTEM_APP.get_component(
-        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+        ComponentType.EXECUTOR_DEFAULT,
+        ExecutorFactory,
+        or_register_component=DefaultExecutorFactory,
     ).create()
@@ -166,7 +174,6 @@ async def db_connect_delete(db_name: str = None):

 async def async_db_summary_embedding(db_name, db_type):
-    # Run the code that needs to execute asynchronously here
     db_summary_client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
     db_summary_client.db_summary_embedding(db_name, db_type)
@@ -200,16 +207,21 @@ async def db_support_types():
 @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
-async def dialogue_list(user_id: str = None):
+async def dialogue_list(
+    user_name: str = None, user_id: str = None, sys_code: str = None
+):
     dialogues: List = []
     chat_history_service = ChatHistory()
     # TODO Change the synchronous call to the asynchronous call
-    datas = chat_history_service.get_store_cls().conv_list(user_id)
+    user_name = user_name or user_id
+    datas = chat_history_service.get_store_cls().conv_list(user_name, sys_code)
     for item in datas:
         conv_uid = item.get("conv_uid")
         summary = item.get("summary")
         chat_mode = item.get("chat_mode")
         model_name = item.get("model_name", CFG.LLM_MODEL)
+        user_name = item.get("user_name")
+        sys_code = item.get("sys_code")

         messages = json.loads(item.get("messages"))
         last_round = max(messages, key=lambda x: x["chat_order"])
@@ -223,6 +235,8 @@ async def dialogue_list(user_id: str = None):
             chat_mode=chat_mode,
             model_name=model_name,
             select_param=select_param,
+            user_name=user_name,
+            sys_code=sys_code,
         )
         dialogues.append(conv_vo)
@@ -254,9 +268,14 @@ async def dialogue_scenes():

 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-    chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
+    chat_mode: str = ChatScene.ChatNormal.value(),
+    user_name: str = None,
+    # TODO remove user id
+    user_id: str = None,
+    sys_code: str = None,
 ):
-    conv_vo = __new_conversation(chat_mode, user_id)
+    user_name = user_name or user_id
+    conv_vo = __new_conversation(chat_mode, user_name, sys_code)
     return Result.succ(conv_vo)
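A quick client-side sketch of how the new fields flow through these two endpoints (assumptions, not from this commit: the webserver is reachable at http://localhost:5000, the router is mounted under /api, the response follows the usual Result envelope with a data field, and the user_name/sys_code values are placeholders):

import requests

BASE = "http://localhost:5000/api"  # assumed mount point of this router

# Create a conversation tagged with a user and a calling system.
new = requests.post(
    f"{BASE}/v1/chat/dialogue/new",
    params={"chat_mode": "chat_normal", "user_name": "alice", "sys_code": "crm"},
).json()
print(new["data"]["conv_uid"], new["data"]["sys_code"])

# List only the conversations belonging to that user and system.
dialogues = requests.get(
    f"{BASE}/v1/chat/dialogue/list",
    params={"user_name": "alice", "sys_code": "crm"},
).json()
print(len(dialogues["data"]))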
@@ -280,40 +299,40 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):

 @router.post("/v1/chat/mode/params/file/load")
 async def params_load(
-    conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
+    conv_uid: str,
+    chat_mode: str,
+    model_name: str,
+    user_name: Optional[str] = None,
+    sys_code: Optional[str] = None,
+    doc_file: UploadFile = File(...),
 ):
     print(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
         if doc_file:
-            ## file save
-            if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
-                os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
-            # We can not move temp file in windows system when we open file in context of `with`
-            tmp_fd, tmp_path = tempfile.mkstemp(
-                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
-            )
-            # TODO Use noblocking file save with aiofiles
-            with os.fdopen(tmp_fd, "wb") as tmp:
-                tmp.write(await doc_file.read())
-            shutil.move(
-                tmp_path,
-                os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename),
-            )
-            ## chat prepare
+            # Save the uploaded file
+            upload_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)
+            os.makedirs(upload_dir, exist_ok=True)
+            upload_path = os.path.join(upload_dir, doc_file.filename)
+            async with aiofiles.open(upload_path, "wb") as f:
+                await f.write(await doc_file.read())
+            # Prepare the chat
             dialogue = ConversationVo(
                 conv_uid=conv_uid,
                 chat_mode=chat_mode,
                 select_param=doc_file.filename,
                 model_name=model_name,
+                user_name=user_name,
+                sys_code=sys_code,
             )
             chat: BaseChat = await get_chat_instance(dialogue)
             resp = await chat.prepare()
-            ### refresh messages
+            # Refresh messages
             return Result.succ(get_hist_messages(conv_uid))
     except Exception as e:
         logger.error("excel load error!", e)
-        return Result.failed(code="E000X", msg=f"File Load Error {e}")
+        return Result.failed(code="E000X", msg=f"File Load Error {str(e)}")
@router.post("/v1/chat/dialogue/delete") @router.post("/v1/chat/dialogue/delete")
@@ -354,7 +373,9 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
     if not dialogue.chat_mode:
         dialogue.chat_mode = ChatScene.ChatNormal.value()
     if not dialogue.conv_uid:
-        conv_vo = __new_conversation(dialogue.chat_mode, dialogue.user_name)
+        conv_vo = __new_conversation(
+            dialogue.chat_mode, dialogue.user_name, dialogue.sys_code
+        )
         dialogue.conv_uid = conv_vo.conv_uid

     if not ChatScene.is_valid_mode(dialogue.chat_mode):
@@ -364,13 +385,12 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
     chat_param = {
         "chat_session_id": dialogue.conv_uid,
+        "user_name": dialogue.user_name,
+        "sys_code": dialogue.sys_code,
         "current_user_input": dialogue.user_input,
         "select_param": dialogue.select_param,
         "model_name": dialogue.model_name,
     }
-    # chat: BaseChat = CHAT_FACTORY.get_implementation(
-    #     dialogue.chat_mode, **{"chat_param": chat_param}
-    # )
     chat: BaseChat = await blocking_func_to_async(
         get_executor(),
         CHAT_FACTORY.get_implementation,
@@ -401,8 +421,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict()
     ):
         chat: BaseChat = await get_chat_instance(dialogue)
-    # background_tasks = BackgroundTasks()
-    # background_tasks.add_task(release_model_semaphore)
     headers = {
         "Content-Type": "text/event-stream",
         "Cache-Control": "no-cache",


@@ -66,6 +66,8 @@ class ConversationVo(BaseModel):
     """
     incremental: bool = False
+    sys_code: Optional[str] = None


 class MessageVo(BaseModel):
     """


@@ -78,7 +78,9 @@ class BaseChat(ABC):
         self.history_message: List[OnceConversation] = self.memory.messages()
         self.current_message: OnceConversation = OnceConversation(
-            self.chat_mode.value()
+            self.chat_mode.value(),
+            user_name=chat_param.get("user_name"),
+            sys_code=chat_param.get("sys_code"),
         )
         self.current_message.model_name = self.llm_model
         if chat_param["select_param"]:
@@ -171,7 +173,6 @@ class BaseChat(ABC):
             "messages": llm_messages,
             "temperature": float(self.prompt_template.temperature),
             "max_new_tokens": int(self.prompt_template.max_new_tokens),
-            # "stop": self.prompt_template.sep,
             "echo": self.llm_echo,
         }
         return payload


@@ -18,7 +18,7 @@ class OnceConversation:
     All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
     """

-    def __init__(self, chat_mode):
+    def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
         self.chat_mode: str = chat_mode
         self.messages: List[BaseMessage] = []
         self.start_date: str = ""
@@ -28,6 +28,8 @@ class OnceConversation:
         self.param_value: str = ""
         self.cost: int = 0
         self.tokens: int = 0
+        self.user_name: str = user_name
+        self.sys_code: str = sys_code

     def add_user_message(self, message: str) -> None:
         """Add a user message to the store"""
@@ -113,6 +115,8 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
         "messages": messages_to_dict(once.messages),
         "param_type": once.param_type,
         "param_value": once.param_value,
+        "user_name": once.user_name,
+        "sys_code": once.sys_code,
     }
@@ -121,7 +125,9 @@ def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
 def conversation_from_dict(once: dict) -> OnceConversation:
-    conversation = OnceConversation()
+    conversation = OnceConversation(
+        once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
+    )
     conversation.cost = once.get("cost", 0)
     conversation.chat_mode = once.get("chat_mode", "chat_normal")
     conversation.tokens = once.get("tokens", 0)
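With both serialization helpers updated, the new fields survive a dict round trip; a quick sketch of that check (assumes pilot.scene.message is importable in the target environment, and the sample values are placeholders):

from pilot.scene.message import (
    OnceConversation,
    _conversation_to_dic,
    conversation_from_dict,
)

conv = OnceConversation("chat_normal", user_name="alice", sys_code="crm")
d = _conversation_to_dic(conv)
assert d["user_name"] == "alice" and d["sys_code"] == "crm"

restored = conversation_from_dict(d)
assert restored.user_name == "alice" and restored.sys_code == "crm"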


@@ -29,6 +29,7 @@ class PromptManageEntity(Base):
     prompt_name = Column(String(512))
     content = Column(Text)
     user_name = Column(String(128))
+    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
     gmt_created = Column(DateTime)
     gmt_modified = Column(DateTime)

setup.py

@@ -1,5 +1,4 @@
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Callable
 import setuptools
 import platform
 import subprocess
@@ -10,6 +9,7 @@ from urllib.parse import urlparse, quote
 import re
 import shutil
 from setuptools import find_packages
+import functools

 with open("README.md", mode="r", encoding="utf-8") as fh:
     long_description = fh.read()
@@ -34,8 +34,15 @@ def parse_requirements(file_name: str) -> List[str]:


 def get_latest_version(package_name: str, index_url: str, default_version: str):
+    python_command = shutil.which("python")
+    if not python_command:
+        python_command = shutil.which("python3")
+        if not python_command:
+            print("Python command not found.")
+            return default_version
+
     command = [
-        "python",
+        python_command,
         "-m",
         "pip",
         "index",
@@ -125,6 +132,7 @@ class OSType(Enum):
     OTHER = "other"


+@functools.cache
 def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
     system = platform.system()
     os_type = OSType.OTHER
@@ -206,6 +214,57 @@ def get_cuda_version() -> str:
     return None


+def _build_wheels(
+    pkg_name: str,
+    pkg_version: str,
+    base_url: str = None,
+    base_url_func: Callable[[str, str, str], str] = None,
+    pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
+    supported_cuda_versions: List[str] = ["11.7", "11.8"],
+) -> Optional[str]:
+    """
+    Build the URL for the package wheel file based on the package name, version, and CUDA version.
+
+    Args:
+        pkg_name (str): The name of the package.
+        pkg_version (str): The version of the package.
+        base_url (str): The base URL for downloading the package.
+        base_url_func (Callable): A function to generate the base URL.
+        pkg_file_func (Callable): A function to build the package file name.
+            Function params: pkg_name, pkg_version, cuda_version, py_version, OSType.
+        supported_cuda_versions (List[str]): The list of supported CUDA versions.
+    Returns:
+        Optional[str]: The URL for the package wheel file.
+    """
+    os_type, _ = get_cpu_avx_support()
+    cuda_version = get_cuda_version()
+    py_version = platform.python_version()
+    py_version = "cp" + "".join(py_version.split(".")[0:2])
+    if os_type == OSType.DARWIN or not cuda_version:
+        return None
+    if cuda_version not in supported_cuda_versions:
+        print(
+            f"Warning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
+        )
+        cuda_version = supported_cuda_versions[-1]
+
+    cuda_version = "cu" + cuda_version.replace(".", "")
+    os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
+    if base_url_func:
+        base_url = base_url_func(pkg_version, cuda_version, py_version)
+        if base_url and base_url.endswith("/"):
+            base_url = base_url[:-1]
+    if pkg_file_func:
+        full_pkg_file = pkg_file_func(
+            pkg_name, pkg_version, cuda_version, py_version, os_type
+        )
+    else:
+        full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+    if not base_url:
+        return full_pkg_file
+    else:
+        return f"{base_url}/{full_pkg_file}"
+
+
 def torch_requires(
     torch_version: str = "2.0.1",
     torchvision_version: str = "0.15.2",
@@ -222,16 +281,20 @@ def torch_requires(
         cuda_version = get_cuda_version()
         if cuda_version:
             supported_versions = ["11.7", "11.8"]
-            if cuda_version not in supported_versions:
-                print(
-                    f"PyTorch version {torch_version} supported cuda version: {supported_versions}, replace to {supported_versions[-1]}"
-                )
-                cuda_version = supported_versions[-1]
-            cuda_version = "cu" + cuda_version.replace(".", "")
-            py_version = "cp310"
-            os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
-            torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
-            torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+            # torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+            # torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
+            torch_url = _build_wheels(
+                "torch",
+                torch_version,
+                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
+                supported_cuda_versions=supported_versions,
+            )
+            torchvision_url = _build_wheels(
+                "torchvision",
+                torchvision_version,
+                base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
+                supported_cuda_versions=supported_versions,
+            )
             torch_url_cached = cache_package(
                 torch_url, "torch", os_type == OSType.WINDOWS
             )
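As a concrete check of the new helper, under an assumed environment of Linux, CUDA 11.8 and CPython 3.10, the torch call above resolves to a single wheel URL; a small sketch reproducing the string _build_wheels assembles:

# Reproduces the URL format used by _build_wheels for the assumed environment.
pkg_name, pkg_version = "torch", "2.0.1"
cuda_version = "cu" + "11.8".replace(".", "")            # -> "cu118"
py_version = "cp" + "".join("3.10.12".split(".")[0:2])   # -> "cp310"
os_pkg_name = "linux_x86_64"
base_url = f"https://download.pytorch.org/whl/{cuda_version}"
full_pkg_file = f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
print(f"{base_url}/{full_pkg_file}")
# https://download.pytorch.org/whl/cu118/torch-2.0.1+cu118-cp310-cp310-linux_x86_64.whl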
@@ -327,6 +390,7 @@ def core_requires():
         "xlrd==2.0.1",
         # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
         "pympler",
+        "aiofiles",
     ]
     if BUILD_FROM_SOURCE:
         setup_spec.extras["framework"].append(
@@ -360,6 +424,41 @@ def llama_cpp_requires():
         llama_cpp_python_cuda_requires()


+def _build_autoawq_requires() -> Optional[str]:
+    os_type, _ = get_cpu_avx_support()
+    if os_type == OSType.DARWIN:
+        return None
+    auto_gptq_version = get_latest_version(
+        "auto-gptq", "https://huggingface.github.io/autogptq-index/whl/cu118/", "0.5.1"
+    )
+    # eg. 0.5.1+cu118
+    auto_gptq_version = auto_gptq_version.split("+")[0]
+
+    def pkg_file_func(pkg_name, pkg_version, cuda_version, py_version, os_type):
+        pkg_name = pkg_name.replace("-", "_")
+        if os_type == OSType.DARWIN:
+            return None
+        os_pkg_name = (
+            "manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
+            if os_type == OSType.LINUX
+            else "win_amd64.whl"
+        )
+        return f"{pkg_name}-{pkg_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}"
+
+    auto_gptq_url = _build_wheels(
+        "auto-gptq",
+        auto_gptq_version,
+        base_url_func=lambda v, x, y: f"https://huggingface.github.io/autogptq-index/whl/{x}/auto-gptq",
+        pkg_file_func=pkg_file_func,
+        supported_cuda_versions=["11.8"],
+    )
+    if auto_gptq_url:
+        print(f"Install auto-gptq from {auto_gptq_url}")
+        return f"auto-gptq @ {auto_gptq_url}"
+    else:
+        return "auto-gptq"
+
+
 def quantization_requires():
     pkgs = []
     os_type, _ = get_cpu_avx_support()
@@ -379,6 +478,28 @@ def quantization_requires():
         print(pkgs)
         # For chatglm2-6b-int4
         pkgs += ["cpm_kernels"]
+
+        # Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM.
+        # autoawq requirements:
+        # 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
+        # 2. CUDA Toolkit 11.8 and later.
+        autoawq_url = _build_wheels(
+            "autoawq",
+            "0.1.7",
+            base_url_func=lambda v, x, y: f"https://github.com/casper-hansen/AutoAWQ/releases/download/v{v}",
+            supported_cuda_versions=["11.8"],
+        )
+        if autoawq_url:
+            print(f"Install autoawq from {autoawq_url}")
+            pkgs.append(f"autoawq @ {autoawq_url}")
+        else:
+            pkgs.append("autoawq")
+
+        auto_gptq_pkg = _build_autoawq_requires()
+        if auto_gptq_pkg:
+            pkgs.append(auto_gptq_pkg)
+            pkgs.append("optimum")
+
     setup_spec.extras["quantization"] = pkgs
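Worked through under the same assumptions as above (Linux, CUDA 11.8, CPython 3.10, and the 0.5.1 fallback version for auto-gptq), the quantization extra ends up pinning direct wheel URLs along these lines; a sketch of the requirement strings the helpers would emit:

# Illustrative output only; the actual auto-gptq version comes from get_latest_version() at build time.
autoawq_req = (
    "autoawq @ https://github.com/casper-hansen/AutoAWQ/releases/download/"
    "v0.1.7/autoawq-0.1.7+cu118-cp310-cp310-linux_x86_64.whl"
)
auto_gptq_req = (
    "auto-gptq @ https://huggingface.github.io/autogptq-index/whl/cu118/auto-gptq/"
    "auto_gptq-0.5.1+cu118-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
)
print(autoawq_req)
print(auto_gptq_req)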