Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-10 12:42:34 +00:00)
feat(editor): ChatExcel
ChatExcel development, part 3
This commit is contained in:
parent
5bbe47d715
commit
7e22d0d1b7
@@ -50,11 +50,11 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
     if df.size <= 0:
         raise ValueError("No Data!")
     plt.rcParams["font.family"] = ["sans-serif"]
     rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
-    sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
-    sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
+    rc = {'font.sans-serif': "Microsoft Yahei"}
+    sns.set(context="notebook", style="whitegrid", color_codes=True, rc=rc)
     plt.subplots(figsize=(8, 5), dpi=100)
     sns.barplot(df, x=df[columns[0]], y=df[columns[1]])

     plt.title("")

     buf = io.BytesIO()
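For context, response_bar_chart renders a DataFrame as a bar chart and returns the image as a string; the hunk ends at buf = io.BytesIO(), so the encoding step is not shown. Below is a minimal, self-contained sketch of the same render-to-buffer pattern, assuming the buffer is base64-encoded for the response (the function name and the encoding step are assumptions, not the project's code).

import base64
import io

import matplotlib
matplotlib.use("Agg")  # headless rendering, no display required
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def bar_chart_as_base64(df: pd.DataFrame) -> str:
    # Render the first two columns as a bar chart and return a base64-encoded PNG.
    if df.size <= 0:
        raise ValueError("No Data!")
    columns = df.columns.tolist()
    rc = {"font.sans-serif": "Microsoft Yahei"}
    sns.set(context="notebook", style="whitegrid", color_codes=True, rc=rc)
    plt.subplots(figsize=(8, 5), dpi=100)
    sns.barplot(data=df, x=columns[0], y=columns[1])
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    return base64.b64encode(buf.getvalue()).decode("ascii")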
@@ -1,9 +1,13 @@
import uuid
import asyncio
import os
import shutil
from fastapi import (
    APIRouter,
    Request,
    File,
    UploadFile,
    Form,
    Body,
    BackgroundTasks,
)
@@ -11,6 +15,7 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.exceptions import RequestValidationError
 from typing import List
+from tempfile import NamedTemporaryFile

 from pilot.openapi.api_view_model import (
     Result,
@@ -31,8 +36,7 @@ from pilot.utils import build_logger
from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.openapi.base import validation_exception_handler

from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH

router = APIRouter()
CFG = Config()
@@ -159,7 +163,7 @@ async def dialogue_scenes():

 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-    chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
+    chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
 ):
     conv_vo = __new_conversation(chat_mode, user_id)
     return Result.succ(conv_vo)
@@ -181,6 +185,37 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
     return Result.succ(None)


+@router.post("/v1/chat/mode/params/file/load")
+async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
+    print(f"params_load: {conv_uid},{chat_mode}")
+    try:
+        if doc_file:
+            ## file save
+            if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
+                os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
+            with NamedTemporaryFile(
+                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False
+            ) as tmp:
+                tmp.write(await doc_file.read())
+                tmp_path = tmp.name
+            shutil.move(
+                tmp_path,
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename
+                ),
+            )
+            ## chat prepare
+            dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename)
+            chat: BaseChat = get_chat_instance(dialogue)
+            resp = chat.prepare()
+
+            ### refresh messages
+            return dialogue_history_messages(conv_uid)
+
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"File Load Error {e}")
+
+
 @router.post("/v1/chat/dialogue/delete")
 async def dialogue_delete(con_uid: str):
     history_mem = DuckdbHistoryMemory(con_uid)
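A quick way to exercise the new upload endpoint from a client: conv_uid and chat_mode arrive as query parameters, and the spreadsheet is sent as a multipart field named doc_file. The sketch below is only illustrative; the local host and port, the conversation id, and the chat_excel scene code are assumptions, not values taken from this commit.

import requests

url = "http://127.0.0.1:5000/v1/chat/mode/params/file/load"        # assumed local deployment
params = {"conv_uid": "demo-conv-uid", "chat_mode": "chat_excel"}   # placeholder values
with open("example.xlsx", "rb") as f:
    resp = requests.post(url, params=params, files={"doc_file": ("example.xlsx", f)})
print(resp.json())  # on success: the refreshed message history for the conversation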
@@ -203,7 +238,8 @@ async def dialogue_history_messages(con_uid: str):
         message_vos.extend(once_message_vos)
     return Result.succ(message_vos)

-def get_chat_instance(dialogue: ConversationVo = Body())-> BaseChat:
+
+def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
     logger.info(f"get_chat_instance:{dialogue}")
     if not dialogue.chat_mode:
         dialogue.chat_mode = ChatScene.ChatNormal.value()
@@ -230,7 +266,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
     logger.info(f"chat_prepare:{dialogue}")
     ## check conv_uid
     chat: BaseChat = get_chat_instance(dialogue)
-    if len(chat.history_message) >0:
+    if len(chat.history_message) > 0:
         return Result.succ(None)
     resp = chat.prepare()
     return Result.succ(resp)
@@ -263,7 +299,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
    )



async def no_stream_generator(chat):
    msg = chat.nostream_call()
    msg = msg.replace("\n", "\\n")
@@ -33,7 +33,6 @@ class ChatScene(Enum):
        code = "excel_learning",
        name = "Excel Learning",
        describe = "Analyze and summarize your excel files.",
        param_types=["File Select"],
        is_inner = True,
    )
    ChatExcel = Scene(
@@ -92,7 +92,8 @@ class BaseChat(ABC):
         self.history_message: List[OnceConversation] = self.memory.messages()
         self.current_message: OnceConversation = OnceConversation(chat_mode.value())
         if select_param:
-            self.current_message.param_type = chat_mode.param_types()[0]
+            if len(chat_mode.param_types()) > 0:
+                self.current_message.param_type = chat_mode.param_types()[0]
             self.current_message.param_value = select_param
         self.current_tokens_used: int = 0

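The BaseChat change wraps the param_type lookup in a length check: a scene whose param_types() comes back empty would otherwise raise as soon as select_param is set. A minimal, standalone illustration of the failure mode (not project code):

param_types = []                     # e.g. a scene that declares no parameter types
# param_type = param_types[0]        # would raise IndexError: list index out of range
param_type = param_types[0] if len(param_types) > 0 else None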
@@ -25,6 +25,8 @@ class ChatExcel(BaseChat):
    chat_retention_rounds = 2
    def __init__(self, chat_session_id, user_input, select_param: str = ""):
        chat_mode = ChatScene.ChatExcel
        ## TODO TEST
        select_param = "/Users/tuyang.yhj/Downloads/example.xlsx"

        self.excel_file_path = select_param
        self.excel_reader = ExcelReader(select_param)
@@ -75,9 +77,11 @@ class ChatExcel(BaseChat):

     def prepare(self):
         logger.info(f"{self.chat_mode} prepare start!")
         if len(self.history_message) > 0:
             return None
         chat_param = {
             "chat_session_id": self.chat_session_id,
-            "user_input": self.excel_reader.excel_file_name + " analysis!",
+            "user_input": "[" + self.excel_reader.excel_file_name +"]" + self.excel_reader.extension + " analysis!",
             "select_param": self.excel_file_path
         }
         learn_chat = ExcelLearning(**chat_param)
@@ -28,7 +28,10 @@ class ExcelReader:
         file_name = os.path.basename(file_path)
         file_name_without_extension = os.path.splitext(file_name)[0]

+
+        self.excel_file_name = file_name_without_extension
+        self.extension = os.path.splitext(file_name)[1]

         self.table_name = file_name_without_extension
         # write data in duckdb
         self.db.register(self.table_name, self.df)
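ExcelReader loads the spreadsheet with pandas and registers the DataFrame as a DuckDB relation, so later steps (such as get_sample_data in the next hunk) can run plain SQL against it. A minimal sketch of that pattern, with an assumed file and table name:

import duckdb
import pandas as pd

df = pd.read_excel("example.xlsx", index_col=False)   # assumed input file
db = duckdb.connect(database=":memory:")
db.register("example", df)                            # expose the DataFrame as a SQL table
print(db.execute("SELECT * FROM example LIMIT 5;").fetchall())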
@@ -49,3 +52,4 @@ class ExcelReader:

    def get_sample_data(self):
        return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;')

@@ -45,13 +45,15 @@ class ChatWithDbAutoExecute(BaseChat):
         except ImportError:
             raise ValueError("Could not import DBSummaryClient. ")
         client = DBSummaryClient()
-        try:
-            table_infos = client.get_db_summary(
-                dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
-            )
-        except Exception as e:
-            print("db summary find error!" + str(e))
-            table_infos = self.database.table_simple_info()
+        # try:
+        #     table_infos = client.get_db_summary(
+        #         dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+        #     )
+        # except Exception as e:
+        #     print("db summary find error!" + str(e))
+        #     table_infos = self.database.table_simple_info()
+        #
+        table_infos = self.database.table_simple_info()

         input_values = {
             "input": self.current_user_input,
@@ -29,7 +29,8 @@ from fastapi.middleware.cors import CORSMiddleware
 from pilot.server.knowledge.api import router as knowledge_router


-from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+from pilot.openapi.api_v1.api_v1 import router as api_v1
+from pilot.openapi.base import validation_exception_handler
 from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1

 logging.basicConfig(level=logging.INFO)
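The import split above separates the v1 router from the validation handler; both still get wired into the FastAPI app. The actual registration code in dbgpt_server.py is outside this hunk, so the sketch below only shows the typical FastAPI calls involved (route prefixes, CORS, and other middleware are omitted and may differ in the project):

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError

from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1

app = FastAPI()
app.include_router(api_v1)
app.include_router(api_editor_route_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)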