diff --git a/pilot/common/formatting.py b/pilot/common/formatting.py index c2db4126a..6bf10c1b2 100644 --- a/pilot/common/formatting.py +++ b/pilot/common/formatting.py @@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, set): return list(obj) - elif hasattr(obj, '__dict__'): + elif hasattr(obj, "__dict__"): return obj.__dict__ else: - return json.JSONEncoder.default(self, obj) \ No newline at end of file + return json.JSONEncoder.default(self, obj) diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py index 832144d22..e22224399 100644 --- a/pilot/common/plugins.py +++ b/pilot/common/plugins.py @@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config): if not cfg.plugins_auto_load: print("not auto load_native_plugins") return + def load_from_git(cfg: Config): print("async load_native_plugins") branch_name = cfg.plugins_git_branch @@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config): url = "https://github.com/csunny/{repo}/archive/{branch}.zip" try: session = requests.Session() - response = session.get(url.format(repo=native_plugin_repo, branch=branch_name), - headers={'Authorization': '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'}) + response = session.get( + url.format(repo=native_plugin_repo, branch=branch_name), + headers={"Authorization": "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"}, + ) if response.status_code == 200: plugins_path_path = Path(PLUGINS_DIR) - files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) + files = glob.glob( + os.path.join(plugins_path_path, f"{native_plugin_repo}*") + ) for file in files: os.remove(file) now = datetime.datetime.now() - time_str = now.strftime('%Y%m%d%H%M%S') + time_str = now.strftime("%Y%m%d%H%M%S") file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" print(file_name) with open(file_name, "wb") as f: @@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config): t.start() - def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: """Scan the plugins directory for plugins and loads them. diff --git a/pilot/common/schema.py b/pilot/common/schema.py index dc7a6c7cc..ab882e30b 100644 --- a/pilot/common/schema.py +++ b/pilot/common/schema.py @@ -8,6 +8,7 @@ class SeparatorStyle(Enum): THREE = auto() FOUR = auto() + class ExampleType(Enum): ONE_SHOT = "one_shot" FEW_SHOT = "few_shot" diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 94ff19e21..7e259f9fc 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -90,7 +90,7 @@ class Config(metaclass=Singleton): ### The associated configuration parameters of the plug-in control the loading and use of the plug-in self.plugins: List[AutoGPTPluginTemplate] = [] self.plugins_openai = [] - self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" + self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard") diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py index 5ad5be08f..31b060ef1 100644 --- a/pilot/connections/rdbms/py_study/pd_study.py +++ b/pilot/connections/rdbms/py_study/pd_study.py @@ -6,7 +6,7 @@ import numpy as np from matplotlib.font_manager import FontProperties from pyecharts.charts import Bar from pyecharts import options as opts -from test_cls_1 import TestBase,Test1 +from test_cls_1 import TestBase, Test1 from test_cls_2 import Test2 CFG = Config() @@ -60,21 +60,21 @@ CFG = Config() # if __name__ == "__main__": - # def __extract_json(s): - # i = s.index("{") - # count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - # for j, c in enumerate(s[i + 1 :], start=i + 1): - # if c == "}": - # count -= 1 - # elif c == "{": - # count += 1 - # if count == 0: - # break - # assert count == 0 # 检查是否找到最后一个'}' - # return s[i : j + 1] - # - # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" - # print(__extract_json(ss)) +# def __extract_json(s): +# i = s.index("{") +# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 +# for j, c in enumerate(s[i + 1 :], start=i + 1): +# if c == "}": +# count -= 1 +# elif c == "{": +# count += 1 +# if count == 0: +# break +# assert count == 0 # 检查是否找到最后一个'}' +# return s[i : j + 1] +# +# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" +# print(__extract_json(ss)) if __name__ == "__main__": test1 = Test1() @@ -83,4 +83,4 @@ if __name__ == "__main__": test1.test() test2.write() test1.test() - test2.test() \ No newline at end of file + test2.test() diff --git a/pilot/connections/rdbms/py_study/test_cls_1.py b/pilot/connections/rdbms/py_study/test_cls_1.py index c7d26a674..1b91b5601 100644 --- a/pilot/connections/rdbms/py_study/test_cls_1.py +++ b/pilot/connections/rdbms/py_study/test_cls_1.py @@ -4,9 +4,9 @@ from test_cls_base import TestBase class Test1(TestBase): - mode:str = "456" + mode: str = "456" + def write(self): self.test_values.append("x") self.test_values.append("y") self.test_values.append("g") - diff --git a/pilot/connections/rdbms/py_study/test_cls_2.py b/pilot/connections/rdbms/py_study/test_cls_2.py index e911f0542..1fb4d5e88 100644 --- a/pilot/connections/rdbms/py_study/test_cls_2.py +++ b/pilot/connections/rdbms/py_study/test_cls_2.py @@ -3,13 +3,15 @@ from pydantic import BaseModel from test_cls_base import TestBase from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union + class Test2(TestBase): - test_2_values:List = [] - mode:str = "789" + test_2_values: List = [] + mode: str = "789" + def write(self): self.test_values.append(1) self.test_values.append(2) self.test_values.append(3) self.test_2_values.append("x") self.test_2_values.append("y") - self.test_2_values.append("z") \ No newline at end of file + self.test_2_values.append("z") diff --git a/pilot/connections/rdbms/py_study/test_cls_base.py b/pilot/connections/rdbms/py_study/test_cls_base.py index b8377c73d..676c1f2a5 100644 --- a/pilot/connections/rdbms/py_study/test_cls_base.py +++ b/pilot/connections/rdbms/py_study/test_cls_base.py @@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union class TestBase(BaseModel, ABC): test_values: List = [] - mode:str = "123" + mode: str = "123" def test(self): - print(self.__class__.__name__ + ":" ) + print(self.__class__.__name__ + ":") print(self.test_values) - print(self.mode) \ No newline at end of file + print(self.mode) diff --git a/pilot/embedding_engine/knowledge_embedding.py b/pilot/embedding_engine/knowledge_embedding.py index 5a8a2f944..fd28b938b 100644 --- a/pilot/embedding_engine/knowledge_embedding.py +++ b/pilot/embedding_engine/knowledge_embedding.py @@ -39,7 +39,9 @@ class KnowledgeEmbedding: return self.knowledge_embedding_client.read_batch() def init_knowledge_embedding(self): - return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config) + return get_knowledge_embedding( + self.knowledge_type, self.knowledge_source, self.vector_store_config + ) def similar_search(self, text, topk): vector_client = VectorStoreConnector( @@ -56,3 +58,9 @@ class KnowledgeEmbedding: CFG.VECTOR_STORE_TYPE, self.vector_store_config ) return vector_client.vector_name_exists() + + def delete_by_ids(self, ids): + vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) + vector_client.delete_by_ids(ids=ids) diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index 649afa4ab..830968504 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC): def clear(self) -> None: """Clear session memory from the local file""" - - def conv_list(self, user_name:str=None) -> None: + def conv_list(self, user_name: str = None) -> None: """get user's conversation list""" pass - diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index b24546d19..d97232217 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -14,13 +14,12 @@ from pilot.common.formatting import MyEncoder default_db_path = os.path.join(os.getcwd(), "message") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") -table_name = 'chat_history' +table_name = "chat_history" CFG = Config() class DuckdbHistoryMemory(BaseChatHistoryMemory): - def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id os.makedirs(default_db_path, exist_ok=True) @@ -28,15 +27,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): self.__init_chat_history_tables() def __init_chat_history_tables(self): - # 检查表是否存在 - result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", - [table_name]).fetchall() + result = self.connect.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name] + ).fetchall() if not result: # 如果表不存在,则创建新表 self.connect.execute( - "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)") + "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)" + ) def __get_messages_by_conv_uid(self, conv_uid: str): cursor = self.connect.cursor() @@ -58,23 +58,46 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): conversations.append(once_message) cursor = self.connect.cursor() if context: - cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id]) + cursor.execute( + "UPDATE chat_history set messages=? where conv_uid=?", + [ + json.dumps( + conversations_to_dict(conversations), + ensure_ascii=False, + indent=4, + ), + self.chat_seesion_id, + ], + ) else: - cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)]) + cursor.execute( + "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", + [ + self.chat_seesion_id, + "", + json.dumps( + conversations_to_dict(conversations), + ensure_ascii=False, + indent=4, + ), + ], + ) cursor.commit() self.connect.commit() def clear(self) -> None: cursor = self.connect.cursor() - cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) cursor.commit() self.connect.commit() def delete(self) -> bool: cursor = self.connect.cursor() - cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) cursor.commit() return True @@ -83,7 +106,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() if user_name: - cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name]) + cursor.execute( + "SELECT * FROM chat_history where user_name=? limit 20", [user_name] + ) else: cursor.execute("SELECT * FROM chat_history limit 20") # 获取查询结果字段名 @@ -99,10 +124,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return [] - - def get_messages(self)-> List[OnceConversation]: + def get_messages(self) -> List[OnceConversation]: cursor = self.connect.cursor() - cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) context = cursor.fetchone() if context: return json.loads(context[0]) diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/mem_history.py index a46428e75..c53595d3d 100644 --- a/pilot/memory/chat_history/mem_history.py +++ b/pilot/memory/chat_history/mem_history.py @@ -11,7 +11,7 @@ from pilot.scene.message import ( conversation_from_dict, conversations_to_dict, ) -from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList +from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList CFG = Config() @@ -19,7 +19,6 @@ CFG = Config() class MemHistoryMemory(BaseChatHistoryMemory): histroies_map = FixedSizeDict(100) - def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id self.histroies_map.update({chat_session_id: []}) diff --git a/pilot/model/cache/__init__.py b/pilot/model/cache/__init__.py index 6a2d6fc7e..37dd0025b 100644 --- a/pilot/model/cache/__init__.py +++ b/pilot/model/cache/__init__.py @@ -1,4 +1,4 @@ from .base import Cache from .disk_cache import DiskCache from .memory_cache import InMemoryCache -from .gpt_cache import GPTCache \ No newline at end of file +from .gpt_cache import GPTCache diff --git a/pilot/model/cache/base.py b/pilot/model/cache/base.py index d8c3d5851..d0b814f84 100644 --- a/pilot/model/cache/base.py +++ b/pilot/model/cache/base.py @@ -3,8 +3,8 @@ import hashlib from typing import Any, Dict from abc import ABC, abstractmethod + class Cache(ABC): - def create(self, key: str) -> bool: pass @@ -24,4 +24,4 @@ class Cache(ABC): @abstractmethod def __contains__(self, key: str) -> bool: """see if we can return a cached value for the passed key""" - pass \ No newline at end of file + pass diff --git a/pilot/model/cache/disk_cache.py b/pilot/model/cache/disk_cache.py index c461a37ae..25895dc9f 100644 --- a/pilot/model/cache/disk_cache.py +++ b/pilot/model/cache/disk_cache.py @@ -3,15 +3,15 @@ import diskcache import platformdirs from pilot.model.cache import Cache + class DiskCache(Cache): """DiskCache is a cache that uses diskcache lib. - https://github.com/grantjenks/python-diskcache + https://github.com/grantjenks/python-diskcache """ + def __init__(self, llm_name: str): self._diskcache = diskcache.Cache( - os.path.join( - platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache" - ) + os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache") ) def __getitem__(self, key: str) -> str: @@ -22,6 +22,6 @@ class DiskCache(Cache): def __contains__(self, key: str) -> bool: return key in self._diskcache - + def clear(self): - self._diskcache.clear() \ No newline at end of file + self._diskcache.clear() diff --git a/pilot/model/cache/gpt_cache.py b/pilot/model/cache/gpt_cache.py index 0fc680510..ee4529a5c 100644 --- a/pilot/model/cache/gpt_cache.py +++ b/pilot/model/cache/gpt_cache.py @@ -9,22 +9,23 @@ try: except ImportError: pass + class GPTCache(Cache): - """ - GPTCache is a semantic cache that uses + """ + GPTCache is a semantic cache that uses """ def __init__(self, cache) -> None: """GPT Cache is a semantic cache that uses GPTCache lib.""" - + if isinstance(cache, str): _cache = Cache() init_similar_cache( data_dir=os.path.join( platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache" ), - cache_obj=_cache + cache_obj=_cache, ) else: _cache = cache @@ -41,4 +42,4 @@ class GPTCache(Cache): return get(key) is not None def create(self, llm: str, **kwargs: Dict[str, Any]) -> str: - pass \ No newline at end of file + pass diff --git a/pilot/model/cache/memory_cache.py b/pilot/model/cache/memory_cache.py index b5311a341..7960b3d2e 100644 --- a/pilot/model/cache/memory_cache.py +++ b/pilot/model/cache/memory_cache.py @@ -1,24 +1,23 @@ from typing import Dict, Any from pilot.model.cache import Cache + class InMemoryCache(Cache): - def __init__(self) -> None: "Initialize that stores things in memory." self._cache: Dict[str, Any] = {} def create(self, key: str) -> bool: - pass + pass def clear(self): return self._cache.clear() def __setitem__(self, key: str, value: str) -> None: self._cache[key] = value - + def __getitem__(self, key: str) -> str: return self._cache[key] - - def __contains__(self, key: str) -> bool: - return self._cache.get(key, None) is not None + def __contains__(self, key: str) -> bool: + return self._cache.get(key, None) is not None diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 376e17ae4..9a3d3d88e 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -12,16 +12,21 @@ from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo +from pilot.server.api_v1.api_view_model import ( + Result, + ConversationVo, + MessageVo, + ChatSceneVo, +) from pilot.configs.config import Config from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory -from pilot.configs.model_config import (LOGDIR) +from pilot.configs.model_config import LOGDIR from pilot.utils import build_logger -from pilot.scene.base_message import (BaseMessage) +from pilot.scene.base_message import BaseMessage from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation @@ -40,19 +45,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE def __get_conv_user_message(conversations: dict): - messages = conversations['messages'] + messages = conversations["messages"] for item in messages: - if item['type'] == "human": - return item['data']['content'] + if item["type"] == "human": + return item["data"]["content"] return "" -@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo]) +@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) async def dialogue_list(response: Response, user_id: str = None): # 设置CORS头部信息 - response.headers['Access-Control-Allow-Origin'] = '*' - response.headers['Access-Control-Allow-Methods'] = 'GET' - response.headers['Access-Control-Request-Headers'] = 'content-type' + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET" + response.headers["Access-Control-Request-Headers"] = "content-type" dialogues: List = [] datas = DuckdbHistoryMemory.conv_list(user_id) @@ -63,26 +68,44 @@ async def dialogue_list(response: Response, user_id: str = None): conversations = json.loads(messages) first_conv: OnceConversation = conversations[0] - conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv), - chat_mode=first_conv['chat_mode']) + conv_vo: ConversationVo = ConversationVo( + conv_uid=conv_uid, + user_input=__get_conv_user_message(first_conv), + chat_mode=first_conv["chat_mode"], + ) dialogues.append(conv_vo) return Result[ConversationVo].succ(dialogues) -@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]]) +@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]]) async def dialogue_scenes(): scene_vos: List[ChatSceneVo] = [] - new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution] + new_modes: List[ChatScene] = [ + ChatScene.ChatDb, + ChatScene.ChatData, + ChatScene.ChatDashboard, + ChatScene.ChatKnowledge, + ChatScene.ChatExecution, + ] for scene in new_modes: - if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]: - scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param") + if not scene.value in [ + ChatScene.ChatNormal.value, + ChatScene.InnerChatDBSummary.value, + ]: + scene_vo = ChatSceneVo( + chat_scene=scene.value, + scene_name=scene.name, + param_title="Selection Param", + ) scene_vos.append(scene_vo) return Result.succ(scene_vos) -@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo]) -async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None): +@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) +async def dialogue_new( + chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None +): unique_id = uuid.uuid1() return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) @@ -90,7 +113,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str def get_db_list(): db = CFG.local_db dbs = db.get_database_list() - params:dict = {} + params: dict = {} for name in dbs: params.update({name: name}) return params @@ -108,7 +131,7 @@ def knowledge_list(): return knowledge_service.get_knowledge_space(request) -@router.post('/v1/chat/mode/params/list', response_model=Result[dict]) +@router.post("/v1/chat/mode/params/list", response_model=Result[dict]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value): if ChatScene.ChatDb.value == chat_mode: return Result.succ(get_db_list()) @@ -124,14 +147,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value): return Result.succ(None) -@router.post('/v1/chat/dialogue/delete') +@router.post("/v1/chat/dialogue/delete") async def dialogue_delete(con_uid: str): history_mem = DuckdbHistoryMemory(con_uid) history_mem.delete() return Result.succ(None) -@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo]) +@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo]) async def dialogue_history_messages(con_uid: str): print(f"dialogue_history_messages:{con_uid}") message_vos: List[MessageVo] = [] @@ -140,17 +163,21 @@ async def dialogue_history_messages(con_uid: str): history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: for once in history_messages: - once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']] + once_message_vos = [ + message2Vo(element, once["chat_order"]) for element in once["messages"] + ] message_vos.extend(once_message_vos) return Result.succ(message_vos) -@router.post('/v1/chat/completions') +@router.post("/v1/chat/completions") async def chat_completions(dialogue: ConversationVo = Body()): print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") if not ChatScene.is_valid_mode(dialogue.chat_mode): - raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")) + raise StopAsyncIteration( + Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!") + ) chat_param = { "chat_session_id": dialogue.conv_uid, @@ -188,7 +215,9 @@ def stream_generator(chat): model_response = chat.stream_call() for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) chat.current_message.add_ai_message(msg) yield msg # chat.current_message.add_ai_message(msg) @@ -206,7 +235,9 @@ def stream_generator(chat): def message2Vo(message: dict, order) -> MessageVo: # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 - return MessageVo(role=message['type'], context=message['data']['content'], order=order) + return MessageVo( + role=message["type"], context=message["data"]["content"], order=order + ) def non_stream_response(chat): @@ -214,38 +245,38 @@ def non_stream_response(chat): return chat.nostream_call() -@router.get('/v1/db/types', response_model=Result[str]) +@router.get("/v1/db/types", response_model=Result[str]) async def db_types(): return Result.succ(["mysql", "duckdb"]) -@router.get('/v1/db/list', response_model=Result[str]) +@router.get("/v1/db/list", response_model=Result[str]) async def db_list(): db = CFG.local_db dbs = db.get_database_list() return Result.succ(dbs) -@router.get('/v1/knowledge/list') +@router.get("/v1/knowledge/list") async def knowledge_list(): return ["test1", "test2"] -@router.post('/v1/knowledge/add') +@router.post("/v1/knowledge/add") async def knowledge_add(): return ["test1", "test2"] -@router.post('/v1/knowledge/delete') +@router.post("/v1/knowledge/delete") async def knowledge_delete(): return ["test1", "test2"] -@router.get('/v1/knowledge/types') +@router.get("/v1/knowledge/types") async def knowledge_types(): return ["test1", "test2"] -@router.get('/v1/knowledge/detail') +@router.get("/v1/knowledge/detail") async def knowledge_detail(): return ["test1", "test2"] diff --git a/pilot/openapi/api_v1/api_view_model.py b/pilot/openapi/api_v1/api_view_model.py index f2d58599b..02eb70ddd 100644 --- a/pilot/openapi/api_v1/api_view_model.py +++ b/pilot/openapi/api_v1/api_view_model.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from typing import TypeVar, Union, List, Generic, Any -T = TypeVar('T') +T = TypeVar("T") class Result(Generic[T], BaseModel): @@ -24,15 +24,17 @@ class Result(Generic[T], BaseModel): class ChatSceneVo(BaseModel): - chat_scene: str = Field(..., description="chat_scene") - scene_name: str = Field(..., description="chat_scene name show for user") - param_title: str = Field(..., description="chat_scene required parameter title") + chat_scene: str = Field(..., description="chat_scene") + scene_name: str = Field(..., description="chat_scene name show for user") + param_title: str = Field(..., description="chat_scene required parameter title") + class ConversationVo(BaseModel): """ dialogue_uid """ - conv_uid: str = Field(..., description="dialogue uid") + + conv_uid: str = Field(..., description="dialogue uid") """ user input """ @@ -44,7 +46,7 @@ class ConversationVo(BaseModel): """ the scene of chat """ - chat_mode: str = Field(..., description="the scene of chat ") + chat_mode: str = Field(..., description="the scene of chat ") """ chat scene select param @@ -52,11 +54,11 @@ class ConversationVo(BaseModel): select_param: str = None - class MessageVo(BaseModel): """ - role that sends out the current message + role that sends out the current message """ + role: str """ current message @@ -70,4 +72,3 @@ class MessageVo(BaseModel): time the current message was sent """ time_stamp: Any = None - diff --git a/pilot/openapi/knowledge/document_chunk_dao.py b/pilot/openapi/knowledge/document_chunk_dao.py index e9a994e66..cb728e85c 100644 --- a/pilot/openapi/knowledge/document_chunk_dao.py +++ b/pilot/openapi/knowledge/document_chunk_dao.py @@ -10,8 +10,10 @@ from pilot.configs.config import Config CFG = Config() Base = declarative_base() + + class DocumentChunkEntity(Base): - __tablename__ = 'document_chunk' + __tablename__ = "document_chunk" id = Column(Integer, primary_key=True) document_id = Column(Integer) doc_name = Column(String(100)) @@ -29,43 +31,55 @@ class DocumentChunkDao: def __init__(self): database = "knowledge_management" self.db_engine = create_engine( - f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', - echo=True) + f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}", + echo=True, + ) self.Session = sessionmaker(bind=self.db_engine) - def create_documents_chunks(self, documents:List): + def create_documents_chunks(self, documents: List): session = self.Session() docs = [ DocumentChunkEntity( - doc_name=document.doc_name, - doc_type=document.doc_type, - document_id=document.document_id, - content=document.content or "", - meta_info=document.meta_info or "", - gmt_created=datetime.now(), - gmt_modified=datetime.now() + doc_name=document.doc_name, + doc_type=document.doc_type, + document_id=document.document_id, + content=document.content or "", + meta_info=document.meta_info or "", + gmt_created=datetime.now(), + gmt_modified=datetime.now(), ) - for document in documents] + for document in documents + ] session.add_all(docs) session.commit() session.close() - def get_document_chunks(self, query:DocumentChunkEntity, page=1, page_size=20): + def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20): session = self.Session() document_chunks = session.query(DocumentChunkEntity) if query.id is not None: document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) if query.document_id is not None: - document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id) + document_chunks = document_chunks.filter( + DocumentChunkEntity.document_id == query.document_id + ) if query.doc_type is not None: - document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type) + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_type == query.doc_type + ) if query.doc_name is not None: - document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name) + document_chunks = document_chunks.filter( + DocumentChunkEntity.doc_name == query.doc_name + ) if query.meta_info is not None: - document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info) + document_chunks = document_chunks.filter( + DocumentChunkEntity.meta_info == query.meta_info + ) document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc()) - document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size) + document_chunks = document_chunks.offset((page - 1) * page_size).limit( + page_size + ) result = document_chunks.all() return result diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py index cf368136f..bebbc8a3f 100644 --- a/pilot/openapi/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.request.knowledge_request import ( KnowledgeQueryRequest, - KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest, + KnowledgeQueryResponse, + KnowledgeDocumentRequest, + DocumentSyncRequest, + ChunkQueryRequest, + DocumentQueryRequest, ) from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest @@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest): def document_list(space_name: str, query_request: DocumentQueryRequest): print(f"/document/list params: {space_name}, {query_request}") try: - return Result.succ(knowledge_space_service.get_knowledge_documents( - space_name, - query_request - )) + return Result.succ( + knowledge_space_service.get_knowledge_documents(space_name, query_request) + ) except Exception as e: return Result.faild(code="E000X", msg=f"document list error {e}") @router.post("/knowledge/{space_name}/document/upload") -def document_sync(space_name: str, file: UploadFile = File(...)): +async def document_sync(space_name: str, file: UploadFile = File(...)): print(f"/document/upload params: {space_name}") try: with NamedTemporaryFile(delete=False) as tmp: @@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest): knowledge_space_service.sync_knowledge_document( space_name=space_name, doc_ids=request.doc_ids ) - Result.succ([]) + return Result.succ([]) except Exception as e: return Result.faild(code="E000X", msg=f"document sync error {e}") @@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest): def document_list(space_name: str, query_request: ChunkQueryRequest): print(f"/document/list params: {space_name}, {query_request}") try: - return Result.succ(knowledge_space_service.get_document_chunks( - query_request - )) + return Result.succ(knowledge_space_service.get_document_chunks(query_request)) except Exception as e: return Result.faild(code="E000X", msg=f"document chunk list error {e}") diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py index 276907b69..1f7afc401 100644 --- a/pilot/openapi/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -12,7 +12,7 @@ Base = declarative_base() class KnowledgeDocumentEntity(Base): - __tablename__ = 'knowledge_document' + __tablename__ = "knowledge_document" id = Column(Integer, primary_key=True) doc_name = Column(String(100)) doc_type = Column(String(100)) @@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base): status = Column(String(100)) last_sync = Column(String(100)) content = Column(Text) + result = Column(Text) vector_ids = Column(Text) gmt_created = Column(DateTime) gmt_modified = Column(DateTime) def __repr__(self): - return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" class KnowledgeDocumentDao: def __init__(self): database = "knowledge_management" self.db_engine = create_engine( - f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', - echo=True) + f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}", + echo=True, + ) self.Session = sessionmaker(bind=self.db_engine) - def create_knowledge_document(self, document:KnowledgeDocumentEntity): + def create_knowledge_document(self, document: KnowledgeDocumentEntity): session = self.Session() knowledge_document = KnowledgeDocumentEntity( doc_name=document.doc_name, @@ -47,9 +49,10 @@ class KnowledgeDocumentDao: status=document.status, last_sync=document.last_sync, content=document.content or "", + result=document.result or "", vector_ids=document.vector_ids, gmt_created=datetime.now(), - gmt_modified=datetime.now() + gmt_modified=datetime.now(), ) session.add(knowledge_document) session.commit() @@ -60,28 +63,42 @@ class KnowledgeDocumentDao: session = self.Session() knowledge_documents = session.query(KnowledgeDocumentEntity) if query.id is not None: - knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id) + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.id == query.id + ) if query.doc_name is not None: - knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name) + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_name == query.doc_name + ) if query.doc_type is not None: - knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type) + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.doc_type == query.doc_type + ) if query.space is not None: - knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space) + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.space == query.space + ) if query.status is not None: - knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status) + knowledge_documents = knowledge_documents.filter( + KnowledgeDocumentEntity.status == query.status + ) - knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc()) - knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size) + knowledge_documents = knowledge_documents.order_by( + KnowledgeDocumentEntity.id.desc() + ) + knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit( + page_size + ) result = knowledge_documents.all() return result - def update_knowledge_document(self, document:KnowledgeDocumentEntity): + def update_knowledge_document(self, document: KnowledgeDocumentEntity): session = self.Session() updated_space = session.merge(document) session.commit() return updated_space.id - def delete_knowledge_document(self, document_id:int): + def delete_knowledge_document(self, document_id: int): cursor = self.conn.cursor() query = "DELETE FROM knowledge_document WHERE id = %s" cursor.execute(query, (document_id,)) diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index b6713dff4..c41630ede 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -4,17 +4,24 @@ from datetime import datetime from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding -from pilot.embedding_engine.knowledge_type import KnowledgeType from pilot.logs import logger -from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao +from pilot.openapi.knowledge.document_chunk_dao import ( + DocumentChunkEntity, + DocumentChunkDao, +) from pilot.openapi.knowledge.knowledge_document_dao import ( KnowledgeDocumentDao, KnowledgeDocumentEntity, ) -from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity +from pilot.openapi.knowledge.knowledge_space_dao import ( + KnowledgeSpaceDao, + KnowledgeSpaceEntity, +) from pilot.openapi.knowledge.request.knowledge_request import ( KnowledgeSpaceRequest, - KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest, + KnowledgeDocumentRequest, + DocumentQueryRequest, + ChunkQueryRequest, ) from enum import Enum @@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() document_chunk_dao = DocumentChunkDao() -CFG=Config() +CFG = Config() class SyncStatus(Enum): @@ -53,10 +60,7 @@ class KnowledgeService: """create knowledge document""" def create_knowledge_document(self, space, request: KnowledgeDocumentRequest): - query = KnowledgeDocumentEntity( - doc_name=request.doc_name, - space=space - ) + query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space) documents = knowledge_document_dao.get_knowledge_documents(query) if len(documents) > 0: raise Exception(f"document name:{request.doc_name} have already named") @@ -74,26 +78,27 @@ class KnowledgeService: """get knowledge space""" - def get_knowledge_space(self, request:KnowledgeSpaceRequest): + def get_knowledge_space(self, request: KnowledgeSpaceRequest): query = KnowledgeSpaceEntity( - name=request.name, - vector_type=request.vector_type, - owner=request.owner + name=request.name, vector_type=request.vector_type, owner=request.owner ) return knowledge_space_dao.get_knowledge_space(query) """get knowledge get_knowledge_documents""" - def get_knowledge_documents(self, space, request:DocumentQueryRequest): + def get_knowledge_documents(self, space, request: DocumentQueryRequest): query = KnowledgeDocumentEntity( doc_name=request.doc_name, doc_type=request.doc_type, space=space, status=request.status, ) - return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size) + return knowledge_document_dao.get_knowledge_documents( + query, page=request.page, page_size=request.page_size + ) """sync knowledge document chunk into vector store""" + def sync_knowledge_document(self, space_name, doc_ids): for doc_id in doc_ids: query = KnowledgeDocumentEntity( @@ -101,12 +106,14 @@ class KnowledgeService: space=space_name, ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] - client = KnowledgeEmbedding(knowledge_source=doc.content, - knowledge_type=doc.doc_type.upper(), - model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - vector_store_config={ - "vector_store_name": space_name, - }) + client = KnowledgeEmbedding( + knowledge_source=doc.content, + knowledge_type=doc.doc_type.upper(), + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + vector_store_config={ + "vector_store_name": space_name, + }, + ) chunk_docs = client.read() # update document status doc.status = SyncStatus.RUNNING.name @@ -114,9 +121,12 @@ class KnowledgeService: doc.gmt_modified = datetime.now() knowledge_document_dao.update_knowledge_document(doc) # async doc embeddings - thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc)) + thread = threading.Thread( + target=self.async_doc_embedding, args=(client, chunk_docs, doc) + ) thread.start() - #save chunk details + logger.info(f"begin save document chunks, doc:{doc.doc_name}") + # save chunk details chunk_entities = [ DocumentChunkEntity( doc_name=doc.doc_name, @@ -125,9 +135,10 @@ class KnowledgeService: content=chunk_doc.page_content, meta_info=str(chunk_doc.metadata), gmt_created=datetime.now(), - gmt_modified=datetime.now() + gmt_modified=datetime.now(), ) - for chunk_doc in chunk_docs] + for chunk_doc in chunk_docs + ] document_chunk_dao.create_documents_chunks(chunk_entities) return True @@ -145,26 +156,30 @@ class KnowledgeService: return knowledge_space_dao.delete_knowledge_space(space_id) """get document chunks""" - def get_document_chunks(self, request:ChunkQueryRequest): + + def get_document_chunks(self, request: ChunkQueryRequest): query = DocumentChunkEntity( id=request.id, document_id=request.document_id, doc_name=request.doc_name, - doc_type=request.doc_type + doc_type=request.doc_type, + ) + return document_chunk_dao.get_document_chunks( + query, page=request.page, page_size=request.page_size ) - return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size) def async_doc_embedding(self, client, chunk_docs, doc): - logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}") + logger.info( + f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}" + ) try: vector_ids = client.knowledge_embedding_batch(chunk_docs) doc.status = SyncStatus.FINISHED.name - doc.content = "embedding success" + doc.result = "document embedding success" doc.vector_ids = ",".join(vector_ids) + logger.info(f"async document embedding, success:{doc.doc_name}") except Exception as e: doc.status = SyncStatus.FAILED.name - doc.content = str(e) - + doc.result = "document embedding failed" + str(e) + logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}") return knowledge_document_dao.update_knowledge_document(doc) - - diff --git a/pilot/openapi/knowledge/knowledge_space_dao.py b/pilot/openapi/knowledge/knowledge_space_dao.py index 31894d6ac..16def1e99 100644 --- a/pilot/openapi/knowledge/knowledge_space_dao.py +++ b/pilot/openapi/knowledge/knowledge_space_dao.py @@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker CFG = Config() Base = declarative_base() + + class KnowledgeSpaceEntity(Base): - __tablename__ = 'knowledge_space' + __tablename__ = "knowledge_space" id = Column(Integer, primary_key=True) name = Column(String(100)) vector_type = Column(String(100)) @@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base): class KnowledgeSpaceDao: def __init__(self): database = "knowledge_management" - self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True) + self.db_engine = create_engine( + f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}", + echo=True, + ) self.Session = sessionmaker(bind=self.db_engine) - def create_knowledge_space(self, space:KnowledgeSpaceRequest): + def create_knowledge_space(self, space: KnowledgeSpaceRequest): session = self.Session() knowledge_space = KnowledgeSpaceEntity( name=space.name, @@ -38,43 +43,61 @@ class KnowledgeSpaceDao: desc=space.desc, owner=space.owner, gmt_created=datetime.now(), - gmt_modified=datetime.now() + gmt_modified=datetime.now(), ) session.add(knowledge_space) session.commit() session.close() - def get_knowledge_space(self, query:KnowledgeSpaceEntity): + def get_knowledge_space(self, query: KnowledgeSpaceEntity): session = self.Session() knowledge_spaces = session.query(KnowledgeSpaceEntity) if query.id is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.id == query.id + ) if query.name is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.name == query.name + ) if query.vector_type is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.vector_type == query.vector_type + ) if query.desc is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.desc == query.desc + ) if query.owner is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.owner == query.owner + ) if query.gmt_created is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.gmt_created == query.gmt_created + ) if query.gmt_modified is not None: - knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified) + knowledge_spaces = knowledge_spaces.filter( + KnowledgeSpaceEntity.gmt_modified == query.gmt_modified + ) - knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc()) + knowledge_spaces = knowledge_spaces.order_by( + KnowledgeSpaceEntity.gmt_created.desc() + ) result = knowledge_spaces.all() return result - def update_knowledge_space(self, space_id:int, space:KnowledgeSpaceEntity): + def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity): cursor = self.conn.cursor() query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s" - cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id)) + cursor.execute( + query, (space.name, space.vector_type, space.desc, space.owner, space_id) + ) self.conn.commit() cursor.close() - def delete_knowledge_space(self, space_id:int): + def delete_knowledge_space(self, space_id: int): cursor = self.conn.cursor() query = "DELETE FROM knowledge_space WHERE id = %s" cursor.execute(query, (space_id,)) diff --git a/pilot/openapi/knowledge/request/knowledge_request.py b/pilot/openapi/knowledge/request/knowledge_request.py index bf68b06f1..1c5916f7c 100644 --- a/pilot/openapi/knowledge/request/knowledge_request.py +++ b/pilot/openapi/knowledge/request/knowledge_request.py @@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel): """doc_type: doc type""" doc_type: str """content: content""" - content: str + content: str = None """text_chunk_size: text_chunk_size""" # text_chunk_size: int + class DocumentQueryRequest(BaseModel): """doc_name: doc path""" + doc_name: str = None """doc_type: doc type""" - doc_type: str= None + doc_type: str = None """status: status""" - status: str= None + status: str = None """page: page""" page: int = 1 """page_size: page size""" @@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel): class DocumentSyncRequest(BaseModel): """doc_ids: doc ids""" + doc_ids: List + class ChunkQueryRequest(BaseModel): """id: id""" + id: int = None """document_id: doc id""" document_id: int = None diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py index 4d876aa51..e930927c9 100644 --- a/pilot/prompts/example_base.py +++ b/pilot/prompts/example_base.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union from pilot.common.schema import ExampleType + class ExampleSelector(BaseModel, ABC): examples: List[List] use_example: bool = False diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 65e107d11..475f82ea4 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser from pilot.common.schema import SeparatorStyle from pilot.prompts.example_base import ExampleSelector + def jinja2_formatter(template: str, **kwargs: Any) -> str: """Format a template using jinja2.""" try: diff --git a/pilot/scene/base.py b/pilot/scene/base.py index cec443beb..d0bb99255 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -1,11 +1,13 @@ from enum import Enum + class Scene: def __init__(self, code, describe, is_inner): self.code = code self.describe = describe self.is_inner = is_inner + class ChatScene(Enum): ChatWithDbExecute = "chat_with_db_execute" ChatWithDbQA = "chat_with_db_qa" @@ -19,9 +21,8 @@ class ChatScene(Enum): ChatDashboard = "chat_dashboard" ChatKnowledge = "chat_knowledge" ChatDb = "chat_db" - ChatData= "chat_data" + ChatData = "chat_data" @staticmethod def is_valid_mode(mode): return any(mode == item.value for item in ChatScene) - diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index d6628a19a..c9ada6b29 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -104,7 +104,9 @@ class BaseChat(ABC): ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) - self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.current_message.start_date = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) # TODO self.current_message.tokens = 0 current_prompt = None @@ -200,8 +202,11 @@ class BaseChat(ABC): # }""" self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - + prompt_define_response = ( + self.prompt_template.output_parser.parse_prompt_response( + ai_response_text + ) + ) result = self.do_action(prompt_define_response) diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 19a59c39a..a9f84e431 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -12,7 +12,10 @@ from pilot.common.markdown_text import ( generate_htm_table, ) from pilot.scene.chat_db.auto_execute.prompt import prompt -from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData +from pilot.scene.chat_dashboard.data_preparation.report_schma import ( + ChartData, + ReportData, +) CFG = Config() @@ -22,9 +25,7 @@ class ChatDashboard(BaseChat): report_name: str """Number of results to return from the query""" - def __init__( - self, chat_session_id, db_name, user_input, report_name - ): + def __init__(self, chat_session_id, db_name, user_input, report_name): """ """ super().__init__( chat_mode=ChatScene.ChatWithDbExecute, @@ -51,7 +52,7 @@ class ChatDashboard(BaseChat): "input": self.current_user_input, "dialect": self.database.dialect, "table_info": self.database.table_simple_info(self.db_connect), - "supported_chat_type": "" #TODO + "supported_chat_type": "" # TODO # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } return input_values @@ -68,7 +69,6 @@ class ChatDashboard(BaseChat): # TODO 修复流程 print(str(e)) - chart_datas.append(chart_data) report_data.conv_uid = self.chat_session_id @@ -77,5 +77,3 @@ class ChatDashboard(BaseChat): report_data.charts = chart_datas return report_data - - diff --git a/pilot/scene/chat_dashboard/data_preparation/report_schma.py b/pilot/scene/chat_dashboard/data_preparation/report_schma.py index 9323aacc3..84468f02e 100644 --- a/pilot/scene/chat_dashboard/data_preparation/report_schma.py +++ b/pilot/scene/chat_dashboard/data_preparation/report_schma.py @@ -12,11 +12,7 @@ class ChartData(BaseModel): class ReportData(BaseModel): - conv_uid:str - template_name:str - template_introduce:str + conv_uid: str + template_name: str + template_introduce: str charts: List[ChartData] - - - - diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py index 975196978..079b5f59a 100644 --- a/pilot/scene/chat_dashboard/out_parser.py +++ b/pilot/scene/chat_dashboard/out_parser.py @@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR class ChartItem(NamedTuple): sql: str - title:str + title: str thoughts: str - showcase:str + showcase: str logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log") @@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser): response = json.loads(clean_str) chart_items = List[ChartItem] for item in response: - chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])) + chart_items.append( + ChartItem( + item["sql"], item["title"], item["thoughts"], item["showcase"] + ) + ) return chart_items def parse_view_response(self, speak, data) -> str: diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index e44144d4d..481ecac22 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical Ensure the response is correct json and can be parsed by Python json.loads """ -RESPONSE_FORMAT = [{ - "sql": "data analysis SQL", - "title": "Data Analysis Title", - "showcase": "What type of charts to show", - "thoughts": "Current thinking and value of data analysis" -}] +RESPONSE_FORMAT = [ + { + "sql": "data analysis SQL", + "title": "Data Analysis Title", + "showcase": "What type of charts to show", + "thoughts": "Current thinking and value of data analysis", + } +] PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 0a87f51f6..613c660c7 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat): """Number of results to return from the query""" - def __init__( - self, chat_session_id, db_name, user_input - ): + def __init__(self, chat_session_id, db_name, user_input): """ """ super().__init__( chat_mode=ChatScene.ChatWithDbExecute, diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 3d6cd0db4..3c8d0eb37 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat): """Number of results to return from the query""" - def __init__( - self, chat_session_id, db_name, user_input - ): + def __init__(self, chat_session_id, db_name, user_input): """ """ super().__init__( chat_mode=ChatScene.ChatWithDbQA, @@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat): "table_info": table_info, } return input_values - - diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index 6cd71b39c..e41de3abd 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector ## Two examples are defined by default EXAMPLES = [ - [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}], - [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}] + [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}], + [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}], ] example = ExampleSelector(examples=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index b5bc38bb2..ceb5db09d 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -36,7 +36,6 @@ RESPONSE_FORMAT = { } - EXAMPLE_TYPE = ExampleType.ONE_SHOT PROMPT_SEP = SeparatorStyle.SINGLE.value ### Whether the model service is streaming output @@ -49,8 +48,10 @@ prompt = PromptTemplate( template_define=PROMPT_SCENE_DEFINE, template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_STREAM_OUT, - output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT), - example=example + output_parser=PluginChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT + ), + example=example, ) CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 370b5518d..9a67b4d75 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat): """Number of results to return from the query""" - def __init__( - self, chat_session_id, user_input, knowledge_name - ): + def __init__(self, chat_session_id, user_input, knowledge_name): """ """ super().__init__( chat_mode=ChatScene.ChatNewKnowledge, @@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat): input_values = {"context": context, "question": self.current_user_input} return input_values - @property def chat_type(self) -> str: return ChatScene.ChatNewKnowledge.value diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 8052de910..51587a048 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -29,7 +29,7 @@ class ChatDefaultKnowledge(BaseChat): """Number of results to return from the query""" - def __init__(self, chat_session_id, user_input): + def __init__(self, chat_session_id, user_input): """ """ super().__init__( chat_mode=ChatScene.ChatDefaultKnowledge, @@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat): ) return input_values - - @property def chat_type(self) -> str: return ChatScene.ChatDefaultKnowledge.value diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py index b4dcc536f..c1e8d279a 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat): } return input_values - @property def chat_type(self) -> str: return ChatScene.InnerChatDBSummary.value diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index bdb964214..aab4bdb9f 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat): input_values = {"context": context, "question": self.current_user_input} return input_values - @property def chat_type(self) -> str: return ChatScene.ChatUrlKnowledge.value diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 321c4d8eb..404a9347b 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -29,7 +29,7 @@ class ChatKnowledge(BaseChat): """Number of results to return from the query""" - def __init__(self, chat_session_id, user_input, knowledge_space): + def __init__(self, chat_session_id, user_input, knowledge_space): """ """ super().__init__( chat_mode=ChatScene.ChatKnowledge, @@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat): ) return input_values - - @property def chat_type(self) -> str: return ChatScene.ChatKnowledge.value diff --git a/pilot/scene/message.py b/pilot/scene/message.py index a2a894fe8..ae32bbfe7 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -85,7 +85,6 @@ class OnceConversation: self.messages.clear() self.session_id = None - def get_user_message(self): for once in self.messages: if isinstance(once, HumanMessage): diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py index 0435c3679..ac72fc637 100644 --- a/pilot/server/__init__.py +++ b/pilot/server/__init__.py @@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): load_dotenv(verbose=True, override=True) del load_dotenv - - diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 3030d1fdc..367671d68 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -27,7 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter CFG = Config() - class ModelWorker: def __init__(self, model_path, model_name, device, num_gpus=1): if model_path.endswith("/"): @@ -106,6 +105,7 @@ worker = ModelWorker( app = FastAPI() from pilot.openapi.knowledge.knowledge_controller import router + app.include_router(router) origins = [ @@ -119,9 +119,10 @@ app.add_middleware( allow_origins=origins, allow_credentials=True, allow_methods=["*"], - allow_headers=["*"] + allow_headers=["*"], ) + class PromptRequest(BaseModel): prompt: str temperature: float diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 9513c0c5b..fd7401833 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -56,7 +56,7 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel from fastapi import FastAPI, applications from fastapi.openapi.docs import get_swagger_ui_html -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles @@ -107,12 +107,16 @@ knowledge_qa_type_list = [ add_knowledge_base_dialogue, ] + def swagger_monkey_patch(*args, **kwargs): return get_swagger_ui_html( - *args, **kwargs, - swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', - swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + *args, + **kwargs, + swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js", + swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css", ) + + applications.get_swagger_ui_html = swagger_monkey_patch app = FastAPI() @@ -360,14 +364,18 @@ def http_bot( response = chat.stream_call() for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) - state.messages[-1][-1] =msg + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) + state.messages[-1][-1] = msg chat.current_message.add_ai_message(msg) yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 chat.memory.append(chat.current_message) except Exception as e: print(traceback.format_exc()) - state.messages[-1][-1] = f"""ERROR!{str(e)} """ + state.messages[-1][ + -1 + ] = f"""ERROR!{str(e)} """ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 @@ -693,8 +701,12 @@ def signal_handler(sig, frame): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) - parser.add_argument('-new', '--new', action='store_true', help='enable new http mode') + parser.add_argument( + "--model_list_mode", type=str, default="once", choices=["once", "reload"] + ) + parser.add_argument( + "-new", "--new", action="store_true", help="enable new http mode" + ) # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") @@ -702,27 +714,24 @@ if __name__ == "__main__": parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--share", default=False, action="store_true") - # init server config args = parser.parse_args() server_init(args) if args.new: import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) else: ### Compatibility mode starts the old version server by default demo = build_webdemo() demo.queue( - concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + concurrency_count=args.concurrency_count, + status_update_rate=10, + api_open=False, ).launch( server_name=args.host, server_port=args.port, share=args.share, max_threads=200, ) - - - - - diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 6caccc474..cf13b3017 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -38,7 +38,9 @@ class KnowledgeEmbedding: return self.knowledge_embedding_client.read_batch() def init_knowledge_embedding(self): - return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config) + return get_knowledge_embedding( + self.file_type.upper(), self.file_path, self.vector_store_config + ) def similar_search(self, text, topk): vector_client = VectorStoreConnector(