From 7e22d0d1b792b8fe1bf9a9ef6e9a93f305063e1b Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 22 Aug 2023 13:55:02 +0800 Subject: [PATCH] feat(editor): ChatExcel ChatExcel devlop part 3 --- pilot/commands/disply_type/show_chart_gen.py | 6 +-- pilot/openapi/api_v1/api_v1.py | 47 ++++++++++++++++--- pilot/scene/base.py | 1 - pilot/scene/base_chat.py | 3 +- .../chat_excel/excel_analyze/chat.py | 6 ++- .../chat_data/chat_excel/excel_reader.py | 4 ++ pilot/scene/chat_db/auto_execute/chat.py | 16 ++++--- pilot/server/dbgpt_server.py | 3 +- 8 files changed, 66 insertions(+), 20 deletions(-) diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py index 78b1ed3fd..fe974be86 100644 --- a/pilot/commands/disply_type/show_chart_gen.py +++ b/pilot/commands/disply_type/show_chart_gen.py @@ -50,11 +50,11 @@ def response_bar_chart(speak: str, df: DataFrame) -> str: if df.size <= 0: raise ValueError("No Data!") plt.rcParams["font.family"] = ["sans-serif"] - rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False} - sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) - sns.set(context="notebook", style="ticks", color_codes=True, rc=rc) + rc = {'font.sans-serif': "Microsoft Yahei"} + sns.set(context="notebook", style="whitegrid", color_codes=True, rc=rc) plt.subplots(figsize=(8, 5), dpi=100) sns.barplot(df, x=df[columns[0]], y=df[columns[1]]) + plt.title("") buf = io.BytesIO() diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index baa1c6263..f994d4c75 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -1,9 +1,13 @@ import uuid import asyncio import os +import shutil from fastapi import ( APIRouter, Request, + File, + UploadFile, + Form, Body, BackgroundTasks, ) @@ -11,6 +15,7 @@ from fastapi import ( from fastapi.responses import StreamingResponse from fastapi.exceptions import RequestValidationError from typing import List +from tempfile import NamedTemporaryFile from pilot.openapi.api_view_model import ( Result, @@ -31,8 +36,7 @@ from pilot.utils import build_logger from pilot.common.schema import DBType from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation -from pilot.openapi.base import validation_exception_handler - +from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH router = APIRouter() CFG = Config() @@ -159,7 +163,7 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None ): conv_vo = __new_conversation(chat_mode, user_id) return Result.succ(conv_vo) @@ -181,6 +185,37 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()): return Result.succ(None) +@router.post("/v1/chat/mode/params/file/load") +async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)): + print(f"params_load: {conv_uid},{chat_mode}") + try: + if doc_file: + ## file save + if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)): + os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)) + with NamedTemporaryFile( + dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False + ) as tmp: + tmp.write(await doc_file.read()) + tmp_path = tmp.name + shutil.move( + tmp_path, + os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename + ), + ) + ## chat prepare + dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename) + chat: BaseChat = get_chat_instance(dialogue) + resp = chat.prepare() + + ### refresh messages + return dialogue_history_messages(conv_uid) + + except Exception as e: + return Result.faild(code="E000X", msg=f"File Load Error {e}") + + @router.post("/v1/chat/dialogue/delete") async def dialogue_delete(con_uid: str): history_mem = DuckdbHistoryMemory(con_uid) @@ -203,7 +238,8 @@ async def dialogue_history_messages(con_uid: str): message_vos.extend(once_message_vos) return Result.succ(message_vos) -def get_chat_instance(dialogue: ConversationVo = Body())-> BaseChat: + +def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: logger.info(f"get_chat_instance:{dialogue}") if not dialogue.chat_mode: dialogue.chat_mode = ChatScene.ChatNormal.value() @@ -230,7 +266,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()): logger.info(f"chat_prepare:{dialogue}") ## check conv_uid chat: BaseChat = get_chat_instance(dialogue) - if len(chat.history_message) >0: + if len(chat.history_message) > 0: return Result.succ(None) resp = chat.prepare() return Result.succ(resp) @@ -263,7 +299,6 @@ async def chat_completions(dialogue: ConversationVo = Body()): ) - async def no_stream_generator(chat): msg = chat.nostream_call() msg = msg.replace("\n", "\\n") diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 2bfa2b11e..eb9113e77 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -33,7 +33,6 @@ class ChatScene(Enum): code = "excel_learning", name = "Excel Learning", describe = "Analyze and summarize your excel files.", - param_types=["File Select"], is_inner = True, ) ChatExcel = Scene( diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 1a76a5795..0ead2f293 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -92,7 +92,8 @@ class BaseChat(ABC): self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value()) if select_param: - self.current_message.param_type = chat_mode.param_types()[0] + if len(chat_mode.param_types()) > 0: + self.current_message.param_type = chat_mode.param_types()[0] self.current_message.param_value = select_param self.current_tokens_used: int = 0 diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 7c8fb8745..83991f4f2 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -25,6 +25,8 @@ class ChatExcel(BaseChat): chat_retention_rounds = 2 def __init__(self, chat_session_id, user_input, select_param: str = ""): chat_mode = ChatScene.ChatExcel + ## TODO TEST + select_param = "/Users/tuyang.yhj/Downloads/example.xlsx" self.excel_file_path = select_param self.excel_reader = ExcelReader(select_param) @@ -75,9 +77,11 @@ class ChatExcel(BaseChat): def prepare(self): logger.info(f"{self.chat_mode} prepare start!") + if len(self.history_message) > 0: + return None chat_param = { "chat_session_id": self.chat_session_id, - "user_input": self.excel_reader.excel_file_name + " analysis!", + "user_input": "[" + self.excel_reader.excel_file_name +"]" + self.excel_reader.extension + " analysis!", "select_param": self.excel_file_path } learn_chat = ExcelLearning(**chat_param) diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index 35d6678f2..9a1aa6f2f 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -28,7 +28,10 @@ class ExcelReader: file_name = os.path.basename(file_path) file_name_without_extension = os.path.splitext(file_name)[0] + self.excel_file_name = file_name_without_extension + self.extension = os.path.splitext(file_name)[1] + self.table_name = file_name_without_extension # write data in duckdb self.db.register(self.table_name, self.df) @@ -49,3 +52,4 @@ class ExcelReader: def get_sample_data(self): return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;') + diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 1160e4e23..6d047e61e 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -45,13 +45,15 @@ class ChatWithDbAutoExecute(BaseChat): except ImportError: raise ValueError("Could not import DBSummaryClient. ") client = DBSummaryClient() - try: - table_infos = client.get_db_summary( - dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE - ) - except Exception as e: - print("db summary find error!" + str(e)) - table_infos = self.database.table_simple_info() + # try: + # table_infos = client.get_db_summary( + # dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE + # ) + # except Exception as e: + # print("db summary find error!" + str(e)) + # table_infos = self.database.table_simple_info() + # + table_infos = self.database.table_simple_info() input_values = { "input": self.current_user_input, diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index edfba8d8b..f8b78b31d 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -29,7 +29,8 @@ from fastapi.middleware.cors import CORSMiddleware from pilot.server.knowledge.api import router as knowledge_router -from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler +from pilot.openapi.api_v1.api_v1 import router as api_v1 +from pilot.openapi.base import validation_exception_handler from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 logging.basicConfig(level=logging.INFO)