diff --git a/pilot/connections/rdbms/py_study/study_data.py b/pilot/connections/rdbms/py_study/study_data.py index c19c39f33..c83c52acc 100644 --- a/pilot/connections/rdbms/py_study/study_data.py +++ b/pilot/connections/rdbms/py_study/study_data.py @@ -7,4 +7,4 @@ if __name__ == "__main__": connect = CFG.local_db.get_session("gpt-user") datas = CFG.local_db.run(connect, "SELECT * FROM users; ") - print(datas) \ No newline at end of file + print(datas) diff --git a/pilot/connections/rdbms/py_study/study_duckdb.py b/pilot/connections/rdbms/py_study/study_duckdb.py index 6e41aa480..20e75f38c 100644 --- a/pilot/connections/rdbms/py_study/study_duckdb.py +++ b/pilot/connections/rdbms/py_study/study_duckdb.py @@ -9,6 +9,8 @@ if __name__ == "__main__": if os.path.isfile("../../../message/chat_history.db"): cursor = duckdb.connect("../../../message/chat_history.db").cursor() # cursor.execute("SELECT * FROM chat_history limit 20") - cursor.execute("SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'") + cursor.execute( + "SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'" + ) data = cursor.fetchall() print(data) diff --git a/pilot/memory/chat_history/base.py b/pilot/memory/chat_history/base.py index ec4af1a14..0f79fb19e 100644 --- a/pilot/memory/chat_history/base.py +++ b/pilot/memory/chat_history/base.py @@ -26,10 +26,9 @@ class BaseChatHistoryMemory(ABC): """Retrieve the messages from the local file""" @abstractmethod - def create(self, user_name:str) -> None: + def create(self, user_name: str) -> None: """Append the message to the record in the local file""" - @abstractmethod def append(self, message: OnceConversation) -> None: """Append the message to the record in the local file""" diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index 6ccba5c4f..e3cf01efc 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ 
b/pilot/memory/chat_history/duckdb_history.py @@ -36,7 +36,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): if not result: # 如果表不存在,则创建新表 self.connect.execute( - "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)") + "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, chat_mode VARCHAR(50), summary VARCHAR(255), user_name VARCHAR(100), messages TEXT)" + ) def __get_messages_by_conv_uid(self, conv_uid: str): cursor = self.connect.cursor() @@ -59,7 +60,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): cursor = self.connect.cursor() cursor.execute( - "INSERT INTO chat_history(conv_uid, chat_mode summary, user_name, messages)VALUES(?,?,?,?,?)", - [self.chat_seesion_id, chat_mode, summary, user_name, ""]) + "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)", + [self.chat_seesion_id, chat_mode, summary, user_name, ""], + ) cursor.commit() self.connect.commit() except Exception as e: @@ -80,7 +82,14 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): else: cursor.execute( "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)", - [self.chat_seesion_id, once_message.chat_mode, once_message.get_user_conv().content, "",json.dumps(conversations, ensure_ascii=False)]) + [ + self.chat_seesion_id, + once_message.chat_mode, + once_message.get_user_conv().content, + "", + json.dumps(conversations, ensure_ascii=False), + ], + ) cursor.commit() self.connect.commit() diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 267deda52..de1cf0f14 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -3,7 +3,15 @@ import json import asyncio import time import os -from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks +from fastapi import ( + APIRouter, + Request, + Body, + status, + HTTPException, + Response, + BackgroundTasks, +) from fastapi.responses import JSONResponse, HTMLResponse from
fastapi.responses import StreamingResponse, FileResponse @@ -12,7 +20,12 @@ from fastapi.exceptions import RequestValidationError from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo +from pilot.openapi.api_v1.api_view_model import ( + Result, + ConversationVo, + MessageVo, + ChatSceneVo, +) from pilot.configs.config import Config from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest @@ -36,6 +49,7 @@ model_semaphore = None global_counter = 0 static_file_path = os.path.join(os.getcwd(), "server/static") + async def validation_exception_handler(request: Request, exc: RequestValidationError): message = "" for error in exc.errors(): @@ -82,6 +96,7 @@ def knowledge_list(): params.update({space.name: space.name}) return params + @router.get("/") async def read_main(): return FileResponse(f"{static_file_path}/test.html") @@ -102,8 +117,6 @@ async def dialogue_list(response: Response, user_id: str = None): summary = item.get("summary") chat_mode = item.get("chat_mode") - - conv_vo: ConversationVo = ConversationVo( conv_uid=conv_uid, user_input=summary, @@ -138,14 +151,14 @@ async def dialogue_scenes(): return Result.succ(scene_vos) - @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None ): conv_vo = __new_conversation(chat_mode, user_id) return Result.succ(conv_vo) + @router.post("/v1/chat/mode/params/list", response_model=Result[dict]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value): if ChatScene.ChatWithDbQA.value == chat_mode: @@ -232,11 +245,19 @@ async def chat_completions(dialogue: ConversationVo = Body()): } if not chat.prompt_template.stream_out: - 
return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream', - background=background_tasks) + return StreamingResponse( + no_stream_generator(chat), + headers=headers, + media_type="text/event-stream", + background=background_tasks, + ) else: - return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain', - background=background_tasks) + return StreamingResponse( + stream_generator(chat), + headers=headers, + media_type="text/plain", + background=background_tasks, + ) def release_model_semaphore(): @@ -254,14 +275,18 @@ async def stream_generator(chat): if not CFG.NEW_SERVER_MODE: for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) msg = msg.replace("\n", "\\n") yield f"data:{msg}\n\n" await asyncio.sleep(0.1) else: for chunk in model_response: if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) msg = msg.replace("\n", "\\n") yield f"data:{msg}\n\n" @@ -273,4 +298,6 @@ async def stream_generator(chat): def message2Vo(message: dict, order) -> MessageVo: - return MessageVo(role=message['type'], context=message['data']['content'], order=order) + return MessageVo( + role=message["type"], context=message["data"]["content"], order=order + ) diff --git a/pilot/openapi/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py index f99b81a72..afbdd4906 100644 --- a/pilot/openapi/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -123,6 +123,7 @@ class KnowledgeDocumentDao: updated_space = session.merge(document) session.commit() return 
updated_space.id + # # def delete_knowledge_document(self, document_id: int): # cursor = self.conn.cursor() diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py index 49e6a0fa3..aafe0a590 100644 --- a/pilot/openapi/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -115,8 +115,13 @@ class KnowledgeService: space=space_name, ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] - if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name: - raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync") + if ( + doc.status == SyncStatus.RUNNING.name + or doc.status == SyncStatus.FINISHED.name + ): + raise Exception( + f" doc:{doc.doc_name} status is {doc.status}, can not sync" + ) client = KnowledgeEmbedding( knowledge_source=doc.content, knowledge_type=doc.doc_type.upper(), diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index dff8528ac..fc1eaaa39 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -107,7 +107,9 @@ class BaseChat(ABC): ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) - self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.current_message.start_date = datetime.datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) self.current_message.tokens = 0 if self.prompt_template.template: diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index dd614399c..805f13eaf 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -65,12 +65,12 @@ class ChatDashboard(BaseChat): try: datas = self.database.run(self.db_connect, chart_item.sql) chart_data: ChartData = ChartData() - chart_data.chart_sql = chart_item['sql'] - chart_data.chart_type = chart_item['showcase'] - chart_data.chart_name = 
chart_item['title'] - chart_data.chart_desc = chart_item['thoughts'] + chart_data.chart_sql = chart_item["sql"] + chart_data.chart_type = chart_item["showcase"] + chart_data.chart_name = chart_item["title"] + chart_data.chart_desc = chart_item["thoughts"] chart_data.column_name = datas[0] - chart_data.values =datas + chart_data.values = datas except Exception as e: # TODO 修复流程 print(str(e)) diff --git a/pilot/scene/chat_db/auto_execute/example.py b/pilot/scene/chat_db/auto_execute/example.py index b4c248d65..73fea6f51 100644 --- a/pilot/scene/chat_db/auto_execute/example.py +++ b/pilot/scene/chat_db/auto_execute/example.py @@ -1,5 +1,6 @@ from pilot.prompts.example_base import ExampleSelector from pilot.common.schema import ExampleType + ## Two examples are defined by default EXAMPLES = [ { @@ -34,4 +35,6 @@ EXAMPLES = [ }, ] -sql_data_example = ExampleSelector(examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value) +sql_data_example = ExampleSelector( + examples_record=EXAMPLES, use_example=True, type=ExampleType.ONE_SHOT.value +) diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index eaa45498c..fe490fb8d 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -9,6 +9,8 @@ from pilot.configs.model_config import LOGDIR from pilot.configs.config import Config CFG = Config() + + class SqlAction(NamedTuple): sql: str thoughts: Dict @@ -35,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser): df = pd.DataFrame(data[1:], columns=data[0]) if CFG.NEW_SERVER_MODE: html = df.to_html(index=False, escape=False, sparsify=False) - html = ''.join(html.split()) + html = "".join(html.split()) else: table_style = """