feat(editor): ChatExcel

ChatExcel devlop part 3
This commit is contained in:
yhjun1026 2023-08-22 13:55:02 +08:00
parent 5bbe47d715
commit 7e22d0d1b7
8 changed files with 66 additions and 20 deletions

View File

@ -50,11 +50,11 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
if df.size <= 0: if df.size <= 0:
raise ValueError("No Data") raise ValueError("No Data")
plt.rcParams["font.family"] = ["sans-serif"] plt.rcParams["font.family"] = ["sans-serif"]
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False} rc = {'font.sans-serif': "Microsoft Yahei"}
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) sns.set(context="notebook", style="whitegrid", color_codes=True, rc=rc)
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
plt.subplots(figsize=(8, 5), dpi=100) plt.subplots(figsize=(8, 5), dpi=100)
sns.barplot(df, x=df[columns[0]], y=df[columns[1]]) sns.barplot(df, x=df[columns[0]], y=df[columns[1]])
plt.title("") plt.title("")
buf = io.BytesIO() buf = io.BytesIO()

View File

@ -1,9 +1,13 @@
import uuid import uuid
import asyncio import asyncio
import os import os
import shutil
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Request, Request,
File,
UploadFile,
Form,
Body, Body,
BackgroundTasks, BackgroundTasks,
) )
@ -11,6 +15,7 @@ from fastapi import (
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from typing import List from typing import List
from tempfile import NamedTemporaryFile
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
Result, Result,
@ -31,8 +36,7 @@ from pilot.utils import build_logger
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
from pilot.openapi.base import validation_exception_handler from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
@ -181,6 +185,37 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
return Result.succ(None) return Result.succ(None)
@router.post("/v1/chat/mode/params/file/load")
async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
print(f"params_load: {conv_uid},{chat_mode}")
try:
if doc_file:
## file save
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
with NamedTemporaryFile(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False
) as tmp:
tmp.write(await doc_file.read())
tmp_path = tmp.name
shutil.move(
tmp_path,
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename
),
)
## chat prepare
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename)
chat: BaseChat = get_chat_instance(dialogue)
resp = chat.prepare()
### refresh messages
return dialogue_history_messages(conv_uid)
except Exception as e:
return Result.faild(code="E000X", msg=f"File Load Error {e}")
@router.post("/v1/chat/dialogue/delete") @router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str): async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
@ -203,7 +238,8 @@ async def dialogue_history_messages(con_uid: str):
message_vos.extend(once_message_vos) message_vos.extend(once_message_vos)
return Result.succ(message_vos) return Result.succ(message_vos)
def get_chat_instance(dialogue: ConversationVo = Body())-> BaseChat:
def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
logger.info(f"get_chat_instance:{dialogue}") logger.info(f"get_chat_instance:{dialogue}")
if not dialogue.chat_mode: if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value() dialogue.chat_mode = ChatScene.ChatNormal.value()
@ -230,7 +266,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
logger.info(f"chat_prepare:{dialogue}") logger.info(f"chat_prepare:{dialogue}")
## check conv_uid ## check conv_uid
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = get_chat_instance(dialogue)
if len(chat.history_message) >0: if len(chat.history_message) > 0:
return Result.succ(None) return Result.succ(None)
resp = chat.prepare() resp = chat.prepare()
return Result.succ(resp) return Result.succ(resp)
@ -263,7 +299,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
) )
async def no_stream_generator(chat): async def no_stream_generator(chat):
msg = chat.nostream_call() msg = chat.nostream_call()
msg = msg.replace("\n", "\\n") msg = msg.replace("\n", "\\n")

View File

@ -33,7 +33,6 @@ class ChatScene(Enum):
code = "excel_learning", code = "excel_learning",
name = "Excel Learning", name = "Excel Learning",
describe = "Analyze and summarize your excel files.", describe = "Analyze and summarize your excel files.",
param_types=["File Select"],
is_inner = True, is_inner = True,
) )
ChatExcel = Scene( ChatExcel = Scene(

View File

@ -92,6 +92,7 @@ class BaseChat(ABC):
self.history_message: List[OnceConversation] = self.memory.messages() self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value()) self.current_message: OnceConversation = OnceConversation(chat_mode.value())
if select_param: if select_param:
if len(chat_mode.param_types()) > 0:
self.current_message.param_type = chat_mode.param_types()[0] self.current_message.param_type = chat_mode.param_types()[0]
self.current_message.param_value = select_param self.current_message.param_value = select_param
self.current_tokens_used: int = 0 self.current_tokens_used: int = 0

View File

@ -25,6 +25,8 @@ class ChatExcel(BaseChat):
chat_retention_rounds = 2 chat_retention_rounds = 2
def __init__(self, chat_session_id, user_input, select_param: str = ""): def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatExcel chat_mode = ChatScene.ChatExcel
## TODO TEST
select_param = "/Users/tuyang.yhj/Downloads/example.xlsx"
self.excel_file_path = select_param self.excel_file_path = select_param
self.excel_reader = ExcelReader(select_param) self.excel_reader = ExcelReader(select_param)
@ -75,9 +77,11 @@ class ChatExcel(BaseChat):
def prepare(self): def prepare(self):
logger.info(f"{self.chat_mode} prepare start!") logger.info(f"{self.chat_mode} prepare start!")
if len(self.history_message) > 0:
return None
chat_param = { chat_param = {
"chat_session_id": self.chat_session_id, "chat_session_id": self.chat_session_id,
"user_input": self.excel_reader.excel_file_name + " analysis", "user_input": "[" + self.excel_reader.excel_file_name +"]" + self.excel_reader.extension + " analysis",
"select_param": self.excel_file_path "select_param": self.excel_file_path
} }
learn_chat = ExcelLearning(**chat_param) learn_chat = ExcelLearning(**chat_param)

View File

@ -28,7 +28,10 @@ class ExcelReader:
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
file_name_without_extension = os.path.splitext(file_name)[0] file_name_without_extension = os.path.splitext(file_name)[0]
self.excel_file_name = file_name_without_extension self.excel_file_name = file_name_without_extension
self.extension = os.path.splitext(file_name)[1]
self.table_name = file_name_without_extension self.table_name = file_name_without_extension
# write data in duckdb # write data in duckdb
self.db.register(self.table_name, self.df) self.db.register(self.table_name, self.df)
@ -49,3 +52,4 @@ class ExcelReader:
def get_sample_data(self): def get_sample_data(self):
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;') return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;')

View File

@ -45,12 +45,14 @@ class ChatWithDbAutoExecute(BaseChat):
except ImportError: except ImportError:
raise ValueError("Could not import DBSummaryClient. ") raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient() client = DBSummaryClient()
try: # try:
table_infos = client.get_db_summary( # table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE # dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
) # )
except Exception as e: # except Exception as e:
print("db summary find error!" + str(e)) # print("db summary find error!" + str(e))
# table_infos = self.database.table_simple_info()
#
table_infos = self.database.table_simple_info() table_infos = self.database.table_simple_info()
input_values = { input_values = {

View File

@ -29,7 +29,8 @@ from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router from pilot.server.knowledge.api import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)