feat(editor): ChatExcel

ChatExcel devlop part 3
This commit is contained in:
yhjun1026 2023-08-22 13:55:02 +08:00
parent 5bbe47d715
commit 7e22d0d1b7
8 changed files with 66 additions and 20 deletions

View File

@ -50,11 +50,11 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
if df.size <= 0:
raise ValueError("No Data")
plt.rcParams["font.family"] = ["sans-serif"]
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
rc = {'font.sans-serif': "Microsoft Yahei"}
sns.set(context="notebook", style="whitegrid", color_codes=True, rc=rc)
plt.subplots(figsize=(8, 5), dpi=100)
sns.barplot(df, x=df[columns[0]], y=df[columns[1]])
plt.title("")
buf = io.BytesIO()

View File

@ -1,9 +1,13 @@
import uuid
import asyncio
import os
import shutil
from fastapi import (
APIRouter,
Request,
File,
UploadFile,
Form,
Body,
BackgroundTasks,
)
@ -11,6 +15,7 @@ from fastapi import (
from fastapi.responses import StreamingResponse
from fastapi.exceptions import RequestValidationError
from typing import List
from tempfile import NamedTemporaryFile
from pilot.openapi.api_view_model import (
Result,
@ -31,8 +36,7 @@ from pilot.utils import build_logger
from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.openapi.base import validation_exception_handler
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
router = APIRouter()
CFG = Config()
@ -159,7 +163,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
):
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)
@ -181,6 +185,37 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
return Result.succ(None)
@router.post("/v1/chat/mode/params/file/load")
async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
print(f"params_load: {conv_uid},{chat_mode}")
try:
if doc_file:
## file save
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
with NamedTemporaryFile(
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False
) as tmp:
tmp.write(await doc_file.read())
tmp_path = tmp.name
shutil.move(
tmp_path,
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename
),
)
## chat prepare
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename)
chat: BaseChat = get_chat_instance(dialogue)
resp = chat.prepare()
### refresh messages
return dialogue_history_messages(conv_uid)
except Exception as e:
return Result.faild(code="E000X", msg=f"File Load Error {e}")
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid)
@ -203,7 +238,8 @@ async def dialogue_history_messages(con_uid: str):
message_vos.extend(once_message_vos)
return Result.succ(message_vos)
def get_chat_instance(dialogue: ConversationVo = Body())-> BaseChat:
def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
logger.info(f"get_chat_instance:{dialogue}")
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value()
@ -230,7 +266,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
logger.info(f"chat_prepare:{dialogue}")
## check conv_uid
chat: BaseChat = get_chat_instance(dialogue)
if len(chat.history_message) >0:
if len(chat.history_message) > 0:
return Result.succ(None)
resp = chat.prepare()
return Result.succ(resp)
@ -263,7 +299,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
)
async def no_stream_generator(chat):
msg = chat.nostream_call()
msg = msg.replace("\n", "\\n")

View File

@ -33,7 +33,6 @@ class ChatScene(Enum):
code = "excel_learning",
name = "Excel Learning",
describe = "Analyze and summarize your excel files.",
param_types=["File Select"],
is_inner = True,
)
ChatExcel = Scene(

View File

@ -92,7 +92,8 @@ class BaseChat(ABC):
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
if select_param:
self.current_message.param_type = chat_mode.param_types()[0]
if len(chat_mode.param_types()) > 0:
self.current_message.param_type = chat_mode.param_types()[0]
self.current_message.param_value = select_param
self.current_tokens_used: int = 0

View File

@ -25,6 +25,8 @@ class ChatExcel(BaseChat):
chat_retention_rounds = 2
def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatExcel
## TODO TEST
select_param = "/Users/tuyang.yhj/Downloads/example.xlsx"
self.excel_file_path = select_param
self.excel_reader = ExcelReader(select_param)
@ -75,9 +77,11 @@ class ChatExcel(BaseChat):
def prepare(self):
logger.info(f"{self.chat_mode} prepare start!")
if len(self.history_message) > 0:
return None
chat_param = {
"chat_session_id": self.chat_session_id,
"user_input": self.excel_reader.excel_file_name + " analysis",
"user_input": "[" + self.excel_reader.excel_file_name +"]" + self.excel_reader.extension + " analysis",
"select_param": self.excel_file_path
}
learn_chat = ExcelLearning(**chat_param)

View File

@ -28,7 +28,10 @@ class ExcelReader:
file_name = os.path.basename(file_path)
file_name_without_extension = os.path.splitext(file_name)[0]
self.excel_file_name = file_name_without_extension
self.extension = os.path.splitext(file_name)[1]
self.table_name = file_name_without_extension
# write data in duckdb
self.db.register(self.table_name, self.df)
@ -49,3 +52,4 @@ class ExcelReader:
def get_sample_data(self):
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;')

View File

@ -45,13 +45,15 @@ class ChatWithDbAutoExecute(BaseChat):
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient()
try:
table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
except Exception as e:
print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info()
# try:
# table_infos = client.get_db_summary(
# dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
# )
# except Exception as e:
# print("db summary find error!" + str(e))
# table_infos = self.database.table_simple_info()
#
table_infos = self.database.table_simple_info()
input_values = {
"input": self.current_user_input,

View File

@ -29,7 +29,8 @@ from fastapi.middleware.cors import CORSMiddleware
from pilot.server.knowledge.api import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
logging.basicConfig(level=logging.INFO)