style: format code

format code
This commit is contained in:
aries_ckt 2023-06-27 22:20:21 +08:00
parent 0878f3c5d4
commit 682b1468d1
47 changed files with 438 additions and 290 deletions

View File

@@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        elif hasattr(obj, "__dict__"):
            return obj.__dict__
        else:
            return json.JSONEncoder.default(self, obj)
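
For context, this default() hook lets json.dumps serialize sets and plain objects; a minimal usage sketch (the Point class is invented for illustration, only MyEncoder comes from the file above):

    import json

    class Point:
        def __init__(self, x, y):
            self.x = x
            self.y = y

    # Sets become lists; other objects fall back to their __dict__.
    print(json.dumps({"tags": {"a"}, "origin": Point(0, 0)}, cls=MyEncoder))
    # -> {"tags": ["a"], "origin": {"x": 0, "y": 0}}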

View File

@@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
    if not cfg.plugins_auto_load:
        print("not auto load_native_plugins")
        return

    def load_from_git(cfg: Config):
        print("async load_native_plugins")
        branch_name = cfg.plugins_git_branch
@@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
        url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
        try:
            session = requests.Session()
            response = session.get(
                url.format(repo=native_plugin_repo, branch=branch_name),
                headers={"Authorization": "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"},
            )

            if response.status_code == 200:
                plugins_path_path = Path(PLUGINS_DIR)
                files = glob.glob(
                    os.path.join(plugins_path_path, f"{native_plugin_repo}*")
                )
                for file in files:
                    os.remove(file)
                now = datetime.datetime.now()
                time_str = now.strftime("%Y%m%d%H%M%S")
                file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
                print(file_name)
                with open(file_name, "wb") as f:
@@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
    t.start()


def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.

View File

@@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
    THREE = auto()
    FOUR = auto()


class ExampleType(Enum):
    ONE_SHOT = "one_shot"
    FEW_SHOT = "few_shot"

View File

@@ -6,7 +6,7 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2

CFG = Config()
@@ -60,21 +60,21 @@ CFG = Config()
# if __name__ == "__main__":
#     def __extract_json(s):
#         i = s.index("{")
#         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
#         for j, c in enumerate(s[i + 1 :], start=i + 1):
#             if c == "}":
#                 count -= 1
#             elif c == "{":
#                 count += 1
#             if count == 0:
#                 break
#         assert count == 0  # check that the final '}' was found
#         return s[i : j + 1]
#
#     ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
#     print(__extract_json(ss))

if __name__ == "__main__":
    test1 = Test1()

View File

@@ -4,9 +4,9 @@ from test_cls_base import TestBase
class Test1(TestBase):
    mode: str = "456"

    def write(self):
        self.test_values.append("x")
        self.test_values.append("y")
        self.test_values.append("g")

View File

@@ -3,9 +3,11 @@ from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class Test2(TestBase):
    test_2_values: List = []
    mode: str = "789"

    def write(self):
        self.test_values.append(1)
        self.test_values.append(2)

View File

@@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC):
    test_values: List = []
    mode: str = "123"

    def test(self):
        print(self.__class__.__name__ + ":")
        print(self.test_values)
        print(self.mode)

View File

@@ -39,7 +39,9 @@ class KnowledgeEmbedding:
        return self.knowledge_embedding_client.read_batch()

    def init_knowledge_embedding(self):
        return get_knowledge_embedding(
            self.knowledge_type, self.knowledge_source, self.vector_store_config
        )

    def similar_search(self, text, topk):
        vector_client = VectorStoreConnector(
@@ -56,3 +58,9 @@ class KnowledgeEmbedding:
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        return vector_client.vector_name_exists()

    def delete_by_ids(self, ids):
        vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        vector_client.delete_by_ids(ids=ids)
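
The newly added delete_by_ids removes previously stored vectors from the configured vector store. A hedged usage sketch; the source path, knowledge type, model name, space name, and ids below are placeholders, not values taken from this commit:

    # All argument values are illustrative; in the service layer they come from
    # the knowledge document being synced.
    client = KnowledgeEmbedding(
        knowledge_source="./docs/example.md",
        knowledge_type="DOCUMENT",
        model_name="text2vec",
        vector_store_config={"vector_store_name": "example_space"},
    )
    client.delete_by_ids(ids=["vec-1", "vec-2"])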

View File

@@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
    def clear(self) -> None:
        """Clear session memory from the local file"""

    def conv_list(self, user_name: str = None) -> None:
        """get user's conversation list"""
        pass

View File

@@ -14,13 +14,12 @@ from pilot.common.formatting import MyEncoder

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
table_name = "chat_history"

CFG = Config()


class DuckdbHistoryMemory(BaseChatHistoryMemory):
    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        os.makedirs(default_db_path, exist_ok=True)
@@ -28,15 +27,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        self.__init_chat_history_tables()

    def __init_chat_history_tables(self):
        # check whether the table already exists
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()

        if not result:
            # if the table does not exist, create it
            self.connect.execute(
                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
            )

    def __get_messages_by_conv_uid(self, conv_uid: str):
        cursor = self.connect.cursor()
@@ -58,23 +58,46 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        conversations.append(once_message)
        cursor = self.connect.cursor()
        if context:
            cursor.execute(
                "UPDATE chat_history set messages=? where conv_uid=?",
                [
                    json.dumps(
                        conversations_to_dict(conversations),
                        ensure_ascii=False,
                        indent=4,
                    ),
                    self.chat_seesion_id,
                ],
            )
        else:
            cursor.execute(
                "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
                [
                    self.chat_seesion_id,
                    "",
                    json.dumps(
                        conversations_to_dict(conversations),
                        ensure_ascii=False,
                        indent=4,
                    ),
                ],
            )
        cursor.commit()
        self.connect.commit()

    def clear(self) -> None:
        cursor = self.connect.cursor()
        cursor.execute(
            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
        )
        cursor.commit()
        self.connect.commit()

    def delete(self) -> bool:
        cursor = self.connect.cursor()
        cursor.execute(
            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
        )
        cursor.commit()
        return True
@@ -83,7 +106,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if user_name:
                cursor.execute(
                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
                )
            else:
                cursor.execute("SELECT * FROM chat_history limit 20")
            # get the column names of the query result
@@ -99,10 +124,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        return []

    def get_messages(self) -> List[OnceConversation]:
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
        )
        context = cursor.fetchone()
        if context:
            return json.loads(context[0])
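
Each DuckdbHistoryMemory instance keys its rows by the conversation id it was constructed with. A rough usage sketch, with the session id invented for illustration (in the server it comes from ConversationVo.conv_uid):

    # The session id is a placeholder.
    history = DuckdbHistoryMemory(chat_session_id="b2c1e0f0-example")
    print(history.get_messages())  # previously stored conversations for this id, if any
    history.clear()                # remove the stored messages for this conversation
    history.delete()               # drop the conversation row entirely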

View File

@@ -19,7 +19,6 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
    histroies_map = FixedSizeDict(100)

    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        self.histroies_map.update({chat_session_id: []})

View File

@@ -3,8 +3,8 @@ import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod


class Cache(ABC):
    def create(self, key: str) -> bool:
        pass

View File

@@ -3,15 +3,15 @@ import diskcache
import platformdirs
from pilot.model.cache import Cache


class DiskCache(Cache):
    """DiskCache is a cache that uses diskcache lib.
    https://github.com/grantjenks/python-diskcache
    """

    def __init__(self, llm_name: str):
        self._diskcache = diskcache.Cache(
            os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
        )

    def __getitem__(self, key: str) -> str:

View File

@@ -9,6 +9,7 @@ try:
except ImportError:
    pass


class GPTCache(Cache):
    """
@@ -24,7 +25,7 @@ class GPTCache(Cache):
                data_dir=os.path.join(
                    platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
                ),
                cache_obj=_cache,
            )
        else:
            _cache = cache

View File

@@ -1,8 +1,8 @@
from typing import Dict, Any
from pilot.model.cache import Cache


class InMemoryCache(Cache):
    def __init__(self) -> None:
        "Initialize that stores things in memory."
        self._cache: Dict[str, Any] = {}
@@ -21,4 +21,3 @@ class InMemoryCache(Cache):
    def __contains__(self, key: str) -> bool:
        return self._cache.get(key, None) is not None

View File

@@ -12,16 +12,21 @@ from fastapi.responses import JSONResponse
from sse_starlette.sse import EventSourceResponse
from typing import List

from pilot.server.api_v1.api_view_model import (
    Result,
    ConversationVo,
    MessageVo,
    ChatSceneVo,
)
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
@@ -40,19 +45,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE


def __get_conv_user_message(conversations: dict):
    messages = conversations["messages"]
    for item in messages:
        if item["type"] == "human":
            return item["data"]["content"]
    return ""


@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
    # set CORS response headers
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Methods"] = "GET"
    response.headers["Access-Control-Request-Headers"] = "content-type"

    dialogues: List = []
    datas = DuckdbHistoryMemory.conv_list(user_id)
@@ -63,26 +68,44 @@ async def dialogue_list(response: Response, user_id: str = None):
        conversations = json.loads(messages)
        first_conv: OnceConversation = conversations[0]
        conv_vo: ConversationVo = ConversationVo(
            conv_uid=conv_uid,
            user_input=__get_conv_user_message(first_conv),
            chat_mode=first_conv["chat_mode"],
        )
        dialogues.append(conv_vo)

    return Result[ConversationVo].succ(dialogues)


@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
    scene_vos: List[ChatSceneVo] = []
    new_modes: List[ChatScene] = [
        ChatScene.ChatDb,
        ChatScene.ChatData,
        ChatScene.ChatDashboard,
        ChatScene.ChatKnowledge,
        ChatScene.ChatExecution,
    ]
    for scene in new_modes:
        if not scene.value in [
            ChatScene.ChatNormal.value,
            ChatScene.InnerChatDBSummary.value,
        ]:
            scene_vo = ChatSceneVo(
                chat_scene=scene.value,
                scene_name=scene.name,
                param_title="Selection Param",
            )
            scene_vos.append(scene_vo)
    return Result.succ(scene_vos)


@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
    unique_id = uuid.uuid1()
    return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -90,7 +113,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str


def get_db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    params: dict = {}
    for name in dbs:
        params.update({name: name})
    return params
@@ -108,7 +131,7 @@ def knowledge_list():
    return knowledge_service.get_knowledge_space(request)


@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
    if ChatScene.ChatDb.value == chat_mode:
        return Result.succ(get_db_list())
@@ -124,14 +147,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
    return Result.succ(None)


@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
    history_mem = DuckdbHistoryMemory(con_uid)
    history_mem.delete()
    return Result.succ(None)


@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
    print(f"dialogue_history_messages:{con_uid}")
    message_vos: List[MessageVo] = []
@@ -140,17 +163,21 @@ async def dialogue_history_messages(con_uid: str):
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        for once in history_messages:
            once_message_vos = [
                message2Vo(element, once["chat_order"]) for element in once["messages"]
            ]
            message_vos.extend(once_message_vos)

    return Result.succ(message_vos)


@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(
            Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
        )

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
@@ -188,7 +215,9 @@ def stream_generator(chat):
    model_response = chat.stream_call()
    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                chunk, chat.skip_echo_len
            )
            chat.current_message.add_ai_message(msg)
            yield msg
            # chat.current_message.add_ai_message(msg)
@@ -206,7 +235,9 @@ def stream_generator(chat):


def message2Vo(message: dict, order) -> MessageVo:
    # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
    return MessageVo(
        role=message["type"], context=message["data"]["content"], order=order
    )


def non_stream_response(chat):
@@ -214,38 +245,38 @@ def non_stream_response(chat):
    return chat.nostream_call()


@router.get("/v1/db/types", response_model=Result[str])
async def db_types():
    return Result.succ(["mysql", "duckdb"])


@router.get("/v1/db/list", response_model=Result[str])
async def db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    return Result.succ(dbs)


@router.get("/v1/knowledge/list")
async def knowledge_list():
    return ["test1", "test2"]


@router.post("/v1/knowledge/add")
async def knowledge_add():
    return ["test1", "test2"]


@router.post("/v1/knowledge/delete")
async def knowledge_delete():
    return ["test1", "test2"]


@router.get("/v1/knowledge/types")
async def knowledge_types():
    return ["test1", "test2"]


@router.get("/v1/knowledge/detail")
async def knowledge_detail():
    return ["test1", "test2"]

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any

T = TypeVar("T")


class Result(Generic[T], BaseModel):
@@ -28,10 +28,12 @@ class ChatSceneVo(BaseModel):
    scene_name: str = Field(..., description="chat_scene name show for user")
    param_title: str = Field(..., description="chat_scene required parameter title")


class ConversationVo(BaseModel):
    """
    dialogue_uid
    """

    conv_uid: str = Field(..., description="dialogue uid")
    """
    user input
@@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
    select_param: str = None


class MessageVo(BaseModel):
    """
    role that sends out the current message
    """

    role: str
    """
    current message
@@ -70,4 +72,3 @@ class MessageVo(BaseModel):
    time the current message was sent
    """
    time_stamp: Any = None

View File

@@ -10,8 +10,10 @@ from pilot.configs.config import Config

CFG = Config()
Base = declarative_base()


class DocumentChunkEntity(Base):
    __tablename__ = "document_chunk"
    id = Column(Integer, primary_key=True)
    document_id = Column(Integer)
    doc_name = Column(String(100))
@@ -29,11 +31,12 @@ class DocumentChunkDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
        self.Session = sessionmaker(bind=self.db_engine)

    def create_documents_chunks(self, documents: List):
        session = self.Session()
        docs = [
            DocumentChunkEntity(
@@ -43,29 +46,40 @@ class DocumentChunkDao:
                content=document.content or "",
                meta_info=document.meta_info or "",
                gmt_created=datetime.now(),
                gmt_modified=datetime.now(),
            )
            for document in documents
        ]
        session.add_all(docs)
        session.commit()
        session.close()

    def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
        session = self.Session()
        document_chunks = session.query(DocumentChunkEntity)
        if query.id is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
        if query.document_id is not None:
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.document_id == query.document_id
            )
        if query.doc_type is not None:
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.doc_type == query.doc_type
            )
        if query.doc_name is not None:
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.doc_name == query.doc_name
            )
        if query.meta_info is not None:
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.meta_info == query.meta_info
            )
        document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
        document_chunks = document_chunks.offset((page - 1) * page_size).limit(
            page_size
        )
        result = document_chunks.all()
        return result

View File

@@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import (
    KnowledgeQueryRequest,
    KnowledgeQueryResponse,
    KnowledgeDocumentRequest,
    DocumentSyncRequest,
    ChunkQueryRequest,
    DocumentQueryRequest,
)
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
def document_list(space_name: str, query_request: DocumentQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(
            knowledge_space_service.get_knowledge_documents(space_name, query_request)
        )
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document list error {e}")


@router.post("/knowledge/{space_name}/document/upload")
async def document_sync(space_name: str, file: UploadFile = File(...)):
    print(f"/document/upload params: {space_name}")
    try:
        with NamedTemporaryFile(delete=False) as tmp:
@@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
        knowledge_space_service.sync_knowledge_document(
            space_name=space_name, doc_ids=request.doc_ids
        )
        return Result.succ([])
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document sync error {e}")
@@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(knowledge_space_service.get_document_chunks(query_request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document chunk list error {e}")

View File

@@ -12,7 +12,7 @@ Base = declarative_base()

class KnowledgeDocumentEntity(Base):
    __tablename__ = "knowledge_document"
    id = Column(Integer, primary_key=True)
    doc_name = Column(String(100))
    doc_type = Column(String(100))
@@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base):
    status = Column(String(100))
    last_sync = Column(String(100))
    content = Column(Text)
    result = Column(Text)
    vector_ids = Column(Text)
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


class KnowledgeDocumentDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
        self.Session = sessionmaker(bind=self.db_engine)

    def create_knowledge_document(self, document: KnowledgeDocumentEntity):
        session = self.Session()
        knowledge_document = KnowledgeDocumentEntity(
            doc_name=document.doc_name,
@@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
            status=document.status,
            last_sync=document.last_sync,
            content=document.content or "",
            result=document.result or "",
            vector_ids=document.vector_ids,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now(),
        )
        session.add(knowledge_document)
        session.commit()
@@ -60,28 +63,42 @@ class KnowledgeDocumentDao:
        session = self.Session()
        knowledge_documents = session.query(KnowledgeDocumentEntity)
        if query.id is not None:
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.id == query.id
            )
        if query.doc_name is not None:
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.doc_name == query.doc_name
            )
        if query.doc_type is not None:
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.doc_type == query.doc_type
            )
        if query.space is not None:
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.space == query.space
            )
        if query.status is not None:
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.status == query.status
            )
        knowledge_documents = knowledge_documents.order_by(
            KnowledgeDocumentEntity.id.desc()
        )
        knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
            page_size
        )
        result = knowledge_documents.all()
        return result

    def update_knowledge_document(self, document: KnowledgeDocumentEntity):
        session = self.Session()
        updated_space = session.merge(document)
        session.commit()
        return updated_space.id

    def delete_knowledge_document(self, document_id: int):
        cursor = self.conn.cursor()
        query = "DELETE FROM knowledge_document WHERE id = %s"
        cursor.execute(query, (document_id,))

View File

@@ -4,17 +4,24 @@ from datetime import datetime
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger
from pilot.openapi.knowledge.document_chunk_dao import (
    DocumentChunkEntity,
    DocumentChunkDao,
)
from pilot.openapi.knowledge.knowledge_document_dao import (
    KnowledgeDocumentDao,
    KnowledgeDocumentEntity,
)
from pilot.openapi.knowledge.knowledge_space_dao import (
    KnowledgeSpaceDao,
    KnowledgeSpaceEntity,
)
from pilot.openapi.knowledge.request.knowledge_request import (
    KnowledgeSpaceRequest,
    KnowledgeDocumentRequest,
    DocumentQueryRequest,
    ChunkQueryRequest,
)
from enum import Enum
@@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()

CFG = Config()


class SyncStatus(Enum):
@@ -53,10 +60,7 @@ class KnowledgeService:
    """create knowledge document"""

    def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
        query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
        documents = knowledge_document_dao.get_knowledge_documents(query)
        if len(documents) > 0:
            raise Exception(f"document name:{request.doc_name} have already named")
@@ -74,26 +78,27 @@ class KnowledgeService:
    """get knowledge space"""

    def get_knowledge_space(self, request: KnowledgeSpaceRequest):
        query = KnowledgeSpaceEntity(
            name=request.name, vector_type=request.vector_type, owner=request.owner
        )
        return knowledge_space_dao.get_knowledge_space(query)

    """get knowledge get_knowledge_documents"""

    def get_knowledge_documents(self, space, request: DocumentQueryRequest):
        query = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            doc_type=request.doc_type,
            space=space,
            status=request.status,
        )
        return knowledge_document_dao.get_knowledge_documents(
            query, page=request.page, page_size=request.page_size
        )

    """sync knowledge document chunk into vector store"""

    def sync_knowledge_document(self, space_name, doc_ids):
        for doc_id in doc_ids:
            query = KnowledgeDocumentEntity(
@@ -101,12 +106,14 @@ class KnowledgeService:
                space=space_name,
            )
            doc = knowledge_document_dao.get_knowledge_documents(query)[0]
            client = KnowledgeEmbedding(
                knowledge_source=doc.content,
                knowledge_type=doc.doc_type.upper(),
                model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                vector_store_config={
                    "vector_store_name": space_name,
                },
            )
            chunk_docs = client.read()
            # update document status
            doc.status = SyncStatus.RUNNING.name
@@ -114,9 +121,12 @@ class KnowledgeService:
            doc.gmt_modified = datetime.now()
            knowledge_document_dao.update_knowledge_document(doc)
            # async doc embeddings
            thread = threading.Thread(
                target=self.async_doc_embedding, args=(client, chunk_docs, doc)
            )
            thread.start()
            logger.info(f"begin save document chunks, doc:{doc.doc_name}")
            # save chunk details
            chunk_entities = [
                DocumentChunkEntity(
                    doc_name=doc.doc_name,
@@ -125,9 +135,10 @@ class KnowledgeService:
                    content=chunk_doc.page_content,
                    meta_info=str(chunk_doc.metadata),
                    gmt_created=datetime.now(),
                    gmt_modified=datetime.now(),
                )
                for chunk_doc in chunk_docs
            ]
            document_chunk_dao.create_documents_chunks(chunk_entities)

        return True
@@ -145,26 +156,30 @@ class KnowledgeService:
        return knowledge_space_dao.delete_knowledge_space(space_id)

    """get document chunks"""

    def get_document_chunks(self, request: ChunkQueryRequest):
        query = DocumentChunkEntity(
            id=request.id,
            document_id=request.document_id,
            doc_name=request.doc_name,
            doc_type=request.doc_type,
        )
        return document_chunk_dao.get_document_chunks(
            query, page=request.page, page_size=request.page_size
        )

    def async_doc_embedding(self, client, chunk_docs, doc):
        logger.info(
            f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
        )
        try:
            vector_ids = client.knowledge_embedding_batch(chunk_docs)
            doc.status = SyncStatus.FINISHED.name
            doc.result = "document embedding success"
            doc.vector_ids = ",".join(vector_ids)
            logger.info(f"async document embedding, success:{doc.doc_name}")
        except Exception as e:
            doc.status = SyncStatus.FAILED.name
            doc.result = "document embedding failed" + str(e)
            logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
        return knowledge_document_dao.update_knowledge_document(doc)
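
One change in this file is more than cosmetic: threading.Thread(target=self.async_doc_embedding, args=(client, chunk_docs, doc)) passes the callable and its arguments separately, whereas the earlier target=self.async_doc_embedding(client, chunk_docs, doc) calls the method immediately on the current thread and hands its return value to Thread. A minimal sketch of the difference (the names are generic, not from this repository):

    import threading

    def embed(doc_name):
        print(f"embedding {doc_name}")

    # Runs embed() on the worker thread, as intended.
    threading.Thread(target=embed, args=("guide.md",)).start()

    # Runs embed() synchronously right here; the Thread then receives embed's
    # return value (None) as its target and does nothing useful when started.
    threading.Thread(target=embed("guide.md")).start()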

View File

@@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker

CFG = Config()
Base = declarative_base()


class KnowledgeSpaceEntity(Base):
    __tablename__ = "knowledge_space"
    id = Column(Integer, primary_key=True)
    name = Column(String(100))
    vector_type = Column(String(100))
@@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base):

class KnowledgeSpaceDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
        self.Session = sessionmaker(bind=self.db_engine)

    def create_knowledge_space(self, space: KnowledgeSpaceRequest):
        session = self.Session()
        knowledge_space = KnowledgeSpaceEntity(
            name=space.name,
@@ -38,43 +43,61 @@ class KnowledgeSpaceDao:
            desc=space.desc,
            owner=space.owner,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now(),
        )
        session.add(knowledge_space)
        session.commit()
        session.close()

    def get_knowledge_space(self, query: KnowledgeSpaceEntity):
        session = self.Session()
        knowledge_spaces = session.query(KnowledgeSpaceEntity)
        if query.id is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.id == query.id
            )
        if query.name is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.name == query.name
            )
        if query.vector_type is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.vector_type == query.vector_type
            )
        if query.desc is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.desc == query.desc
            )
        if query.owner is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.owner == query.owner
            )
        if query.gmt_created is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.gmt_created == query.gmt_created
            )
        if query.gmt_modified is not None:
            knowledge_spaces = knowledge_spaces.filter(
                KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
            )
        knowledge_spaces = knowledge_spaces.order_by(
            KnowledgeSpaceEntity.gmt_created.desc()
        )
        result = knowledge_spaces.all()
        return result

    def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
        cursor = self.conn.cursor()
        query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
        cursor.execute(
            query, (space.name, space.vector_type, space.desc, space.owner, space_id)
        )
        self.conn.commit()
        cursor.close()

    def delete_knowledge_space(self, space_id: int):
        cursor = self.conn.cursor()
        query = "DELETE FROM knowledge_space WHERE id = %s"
        cursor.execute(query, (space_id,))

View File

@@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel):
    """doc_type: doc type"""
    doc_type: str
    """content: content"""
    content: str = None
    """text_chunk_size: text_chunk_size"""
    # text_chunk_size: int


class DocumentQueryRequest(BaseModel):
    """doc_name: doc path"""
    doc_name: str = None
    """doc_type: doc type"""
    doc_type: str = None
    """status: status"""
    status: str = None
    """page: page"""
    page: int = 1
    """page_size: page size"""
@@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):

class DocumentSyncRequest(BaseModel):
    """doc_ids: doc ids"""
    doc_ids: List


class ChunkQueryRequest(BaseModel):
    """id: id"""
    id: int = None
    """document_id: doc id"""
    document_id: int = None

View File

@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pilot.common.schema import ExampleType


class ExampleSelector(BaseModel, ABC):
    examples: List[List]
    use_example: bool = False

View File

@@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector


def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2."""
    try:

View File

@ -1,11 +1,13 @@
from enum import Enum from enum import Enum
class Scene: class Scene:
def __init__(self, code, describe, is_inner): def __init__(self, code, describe, is_inner):
self.code = code self.code = code
self.describe = describe self.describe = describe
self.is_inner = is_inner self.is_inner = is_inner
class ChatScene(Enum): class ChatScene(Enum):
ChatWithDbExecute = "chat_with_db_execute" ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa" ChatWithDbQA = "chat_with_db_qa"
@@ -19,9 +21,8 @@ class ChatScene(Enum):
ChatDashboard = "chat_dashboard" ChatDashboard = "chat_dashboard"
ChatKnowledge = "chat_knowledge" ChatKnowledge = "chat_knowledge"
ChatDb = "chat_db" ChatDb = "chat_db"
ChatData= "chat_data" ChatData = "chat_data"
@staticmethod @staticmethod
def is_valid_mode(mode): def is_valid_mode(mode):
return any(mode == item.value for item in ChatScene) return any(mode == item.value for item in ChatScene)
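A condensed, self-contained illustration of the membership check defined above; only two of the scene values are reproduced here:

from enum import Enum

class ChatScene(Enum):
    ChatDashboard = "chat_dashboard"
    ChatKnowledge = "chat_knowledge"

    @staticmethod
    def is_valid_mode(mode):
        # valid when the string matches any member's value
        return any(mode == item.value for item in ChatScene)

assert ChatScene.is_valid_mode("chat_dashboard")
assert not ChatScene.is_valid_mode("not_a_scene")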

View File

@@ -104,7 +104,9 @@ class BaseChat(ABC):
### Chat sequence advance ### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1 self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input) self.current_message.add_user_message(self.current_user_input)
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.current_message.start_date = datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
# TODO # TODO
self.current_message.tokens = 0 self.current_message.tokens = 0
current_prompt = None current_prompt = None
@@ -200,8 +202,11 @@ class BaseChat(ABC):
# }""" # }"""
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_action(prompt_define_response) result = self.do_action(prompt_define_response)
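For reference, the timestamp format used for start_date above produces second-resolution strings such as the following; the printed value is only an example:

import datetime

start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(start_date)  # e.g. "2024-01-31 09:15:00"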

View File

@@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
generate_htm_table, generate_htm_table,
) )
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
)
CFG = Config() CFG = Config()
@@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
report_name: str report_name: str
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, db_name, user_input, report_name):
self, chat_session_id, db_name, user_input, report_name
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbExecute, chat_mode=ChatScene.ChatWithDbExecute,
@@ -51,7 +52,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input, "input": self.current_user_input,
"dialect": self.database.dialect, "dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect), "table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": "" #TODO "supported_chat_type": "" # TODO
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
} }
return input_values return input_values
@@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
# TODO fix this flow # TODO fix this flow
print(str(e)) print(str(e))
chart_datas.append(chart_data) chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id report_data.conv_uid = self.chat_session_id
@@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
report_data.charts = chart_datas report_data.charts = chart_datas
return report_data return report_data

View File

@@ -12,11 +12,7 @@ class ChartData(BaseModel):
class ReportData(BaseModel): class ReportData(BaseModel):
conv_uid:str conv_uid: str
template_name:str template_name: str
template_introduce:str template_introduce: str
charts: List[ChartData] charts: List[ChartData]

View File

@@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR
class ChartItem(NamedTuple): class ChartItem(NamedTuple):
sql: str sql: str
title:str title: str
thoughts: str thoughts: str
showcase:str showcase: str
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log") logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
@@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
response = json.loads(clean_str) response = json.loads(clean_str)
chart_items = List[ChartItem] chart_items = List[ChartItem]
for item in response: for item in response:
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])) chart_items.append(
ChartItem(
item["sql"], item["title"], item["thoughts"], item["showcase"]
)
)
return chart_items return chart_items
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
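One thing worth noting in the hunk above: chart_items is bound to the typing construct List[ChartItem] rather than an actual list, so the later append call would fail at runtime. A working sketch of the same parsing step, assuming the model reply is a JSON array of chart objects:

import json
from typing import List, NamedTuple

class ChartItem(NamedTuple):
    sql: str
    title: str
    thoughts: str
    showcase: str

def parse_chart_items(clean_str: str) -> List[ChartItem]:
    response = json.loads(clean_str)
    chart_items = []  # start from an empty list, not from the type annotation
    for item in response:
        chart_items.append(
            ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])
        )
    return chart_items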

View File

@@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
Ensure the response is correct json and can be parsed by Python json.loads Ensure the response is correct json and can be parsed by Python json.loads
""" """
RESPONSE_FORMAT = [{ RESPONSE_FORMAT = [
{
"sql": "data analysis SQL", "sql": "data analysis SQL",
"title": "Data Analysis Title", "title": "Data Analysis Title",
"showcase": "What type of charts to show", "showcase": "What type of charts to show",
"thoughts": "Current thinking and value of data analysis" "thoughts": "Current thinking and value of data analysis",
}] }
]
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value
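Presumably RESPONSE_FORMAT is serialized into the prompt so the model can mirror it back as parseable JSON; a small sketch of that round trip, with the wrapper text invented for illustration:

import json

RESPONSE_FORMAT = [
    {
        "sql": "data analysis SQL",
        "title": "Data Analysis Title",
        "showcase": "What type of charts to show",
        "thoughts": "Current thinking and value of data analysis",
    }
]

prompt_suffix = "Respond strictly in this JSON format:\n" + json.dumps(RESPONSE_FORMAT, indent=2)
parsed = json.loads(json.dumps(RESPONSE_FORMAT))  # round-trips cleanly through json.loads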

View File

@@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, db_name, user_input):
self, chat_session_id, db_name, user_input
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbExecute, chat_mode=ChatScene.ChatWithDbExecute,

View File

@@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, db_name, user_input):
self, chat_session_id, db_name, user_input
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbQA, chat_mode=ChatScene.ChatWithDbQA,
@@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
"table_info": table_info, "table_info": table_info,
} }
return input_values return input_values

View File

@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default ## Two examples are defined by default
EXAMPLES = [ EXAMPLES = [
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}], [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}] [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
] ]
example = ExampleSelector(examples=EXAMPLES, use_example=True) example = ExampleSelector(examples=EXAMPLES, use_example=True)
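A self-contained sketch of constructing the selector shown above, reusing the two fields from its definition earlier in this commit; the print is only to show the shape:

from typing import List
from pydantic import BaseModel

class ExampleSelector(BaseModel):
    examples: List[List]
    use_example: bool = False

EXAMPLES = [
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]
example = ExampleSelector(examples=EXAMPLES, use_example=True)
print(len(example.examples))  # 1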

View File

@@ -36,7 +36,6 @@ RESPONSE_FORMAT = {
} }
EXAMPLE_TYPE = ExampleType.ONE_SHOT EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output ### Whether the model service is streaming output
@@ -49,8 +48,10 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT, stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT), output_parser=PluginChatOutputParser(
example=example sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
example=example,
) )
CFG.prompt_templates.update({prompt.template_scene: prompt}) CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, user_input, knowledge_name):
self, chat_session_id, user_input, knowledge_name
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatNewKnowledge, chat_mode=ChatScene.ChatNewKnowledge,
@@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatNewKnowledge.value return ChatScene.ChatNewKnowledge.value

View File

@@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
) )
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatDefaultKnowledge.value return ChatScene.ChatDefaultKnowledge.value

View File

@@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
} }
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.InnerChatDBSummary.value return ChatScene.InnerChatDBSummary.value

View File

@@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatUrlKnowledge.value return ChatScene.ChatUrlKnowledge.value

View File

@@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
) )
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value return ChatScene.ChatKnowledge.value

View File

@@ -85,7 +85,6 @@ class OnceConversation:
self.messages.clear() self.messages.clear()
self.session_id = None self.session_id = None
def get_user_message(self): def get_user_message(self):
for once in self.messages: for once in self.messages:
if isinstance(once, HumanMessage): if isinstance(once, HumanMessage):

View File

@@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True) load_dotenv(verbose=True, override=True)
del load_dotenv del load_dotenv

View File

@@ -27,7 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config() CFG = Config()
class ModelWorker: class ModelWorker:
def __init__(self, model_path, model_name, device, num_gpus=1): def __init__(self, model_path, model_name, device, num_gpus=1):
if model_path.endswith("/"): if model_path.endswith("/"):
@@ -106,6 +105,7 @@ worker = ModelWorker(
app = FastAPI() app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router) app.include_router(router)
origins = [ origins = [
@@ -119,9 +119,10 @@ app.add_middleware(
allow_origins=origins, allow_origins=origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"] allow_headers=["*"],
) )
class PromptRequest(BaseModel): class PromptRequest(BaseModel):
prompt: str prompt: str
temperature: float temperature: float
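The CORS wiring above follows the standard FastAPI pattern; a self-contained sketch with an illustrative origins list:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
origins = ["http://localhost:3000"]  # illustrative origin

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)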

View File

@@ -107,12 +107,16 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue, add_knowledge_base_dialogue,
] ]
def swagger_monkey_patch(*args, **kwargs): def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html( return get_swagger_ui_html(
*args, **kwargs, *args,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', **kwargs,
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
) )
applications.get_swagger_ui_html = swagger_monkey_patch applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI() app = FastAPI()
@@ -360,14 +364,18 @@ def http_bot(
response = chat.stream_call() response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
state.messages[-1][-1] =msg chunk, chat.skip_echo_len
)
state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg) chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message) chat.memory.append(chat.current_message)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """ state.messages[-1][
-1
] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
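The loop above consumes a streamed HTTP response split on NUL bytes via requests' iter_lines; a minimal sketch of the client side, with the endpoint URL and payload invented for illustration:

import requests

response = requests.post(
    "http://localhost:8000/generate_stream",  # assumed endpoint
    json={"prompt": "hello"},
    stream=True,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        print(chunk.decode("utf-8"))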
@@ -693,8 +701,12 @@ def signal_handler(sig, frame):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) parser.add_argument(
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode') "--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument(
"-new", "--new", action="store_true", help="enable new http mode"
)
# old version server config # old version server config
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -702,27 +714,24 @@ if __name__ == "__main__":
parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true") parser.add_argument("--share", default=False, action="store_true")
# init server config # init server config
args = parser.parse_args() args = parser.parse_args()
server_init(args) server_init(args)
if args.new: if args.new:
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000) uvicorn.run(app, host="0.0.0.0", port=5000)
else: else:
### Compatibility mode starts the old version server by default ### Compatibility mode starts the old version server by default
demo = build_webdemo() demo = build_webdemo()
demo.queue( demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False concurrency_count=args.concurrency_count,
status_update_rate=10,
api_open=False,
).launch( ).launch(
server_name=args.host, server_name=args.host,
server_port=args.port, server_port=args.port,
share=args.share, share=args.share,
max_threads=200, max_threads=200,
) )
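The --new switch above is a plain store_true flag that picks the uvicorn code path over the Gradio one; a tiny self-contained check of that flag handling:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-new", "--new", action="store_true", help="enable new http mode")

args = parser.parse_args(["--new"])
assert args.new is True                    # flag present
assert parser.parse_args([]).new is False  # flag absent defaults to False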

View File

@@ -38,7 +38,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch() return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self): def init_knowledge_embedding(self):
return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config) return get_knowledge_embedding(
self.file_type.upper(), self.file_path, self.vector_store_config
)
def similar_search(self, text, topk): def similar_search(self, text, topk):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(