mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-14 06:26:18 +00:00
feat(editor): ChatExcel
ChatExcel devlop part 3
This commit is contained in:
parent
5bbe47d715
commit
7e22d0d1b7
@ -50,11 +50,11 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
|
|||||||
if df.size <= 0:
|
if df.size <= 0:
|
||||||
raise ValueError("No Data!")
|
raise ValueError("No Data!")
|
||||||
plt.rcParams["font.family"] = ["sans-serif"]
|
plt.rcParams["font.family"] = ["sans-serif"]
|
||||||
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
|
rc = {'font.sans-serif': "Microsoft Yahei"}
|
||||||
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
|
sns.set(context="notebook", style="whitegrid", color_codes=True, rc=rc)
|
||||||
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
|
|
||||||
plt.subplots(figsize=(8, 5), dpi=100)
|
plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
sns.barplot(df, x=df[columns[0]], y=df[columns[1]])
|
sns.barplot(df, x=df[columns[0]], y=df[columns[1]])
|
||||||
|
|
||||||
plt.title("")
|
plt.title("")
|
||||||
|
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Request,
|
Request,
|
||||||
|
File,
|
||||||
|
UploadFile,
|
||||||
|
Form,
|
||||||
Body,
|
Body,
|
||||||
BackgroundTasks,
|
BackgroundTasks,
|
||||||
)
|
)
|
||||||
@ -11,6 +15,7 @@ from fastapi import (
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
from pilot.openapi.api_view_model import (
|
from pilot.openapi.api_view_model import (
|
||||||
Result,
|
Result,
|
||||||
@ -31,8 +36,7 @@ from pilot.utils import build_logger
|
|||||||
from pilot.common.schema import DBType
|
from pilot.common.schema import DBType
|
||||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
from pilot.scene.message import OnceConversation
|
from pilot.scene.message import OnceConversation
|
||||||
from pilot.openapi.base import validation_exception_handler
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -159,7 +163,7 @@ async def dialogue_scenes():
|
|||||||
|
|
||||||
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
|
||||||
async def dialogue_new(
|
async def dialogue_new(
|
||||||
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
|
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
|
||||||
):
|
):
|
||||||
conv_vo = __new_conversation(chat_mode, user_id)
|
conv_vo = __new_conversation(chat_mode, user_id)
|
||||||
return Result.succ(conv_vo)
|
return Result.succ(conv_vo)
|
||||||
@ -181,6 +185,37 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
|
|||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/chat/mode/params/file/load")
|
||||||
|
async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
|
||||||
|
print(f"params_load: {conv_uid},{chat_mode}")
|
||||||
|
try:
|
||||||
|
if doc_file:
|
||||||
|
## file save
|
||||||
|
if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode)):
|
||||||
|
os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode))
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode), delete=False
|
||||||
|
) as tmp:
|
||||||
|
tmp.write(await doc_file.read())
|
||||||
|
tmp_path = tmp.name
|
||||||
|
shutil.move(
|
||||||
|
tmp_path,
|
||||||
|
os.path.join(
|
||||||
|
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode, doc_file.filename
|
||||||
|
),
|
||||||
|
)
|
||||||
|
## chat prepare
|
||||||
|
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename)
|
||||||
|
chat: BaseChat = get_chat_instance(dialogue)
|
||||||
|
resp = chat.prepare()
|
||||||
|
|
||||||
|
### refresh messages
|
||||||
|
return dialogue_history_messages(conv_uid)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return Result.faild(code="E000X", msg=f"File Load Error {e}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/dialogue/delete")
|
@router.post("/v1/chat/dialogue/delete")
|
||||||
async def dialogue_delete(con_uid: str):
|
async def dialogue_delete(con_uid: str):
|
||||||
history_mem = DuckdbHistoryMemory(con_uid)
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
@ -203,7 +238,8 @@ async def dialogue_history_messages(con_uid: str):
|
|||||||
message_vos.extend(once_message_vos)
|
message_vos.extend(once_message_vos)
|
||||||
return Result.succ(message_vos)
|
return Result.succ(message_vos)
|
||||||
|
|
||||||
def get_chat_instance(dialogue: ConversationVo = Body())-> BaseChat:
|
|
||||||
|
def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
||||||
logger.info(f"get_chat_instance:{dialogue}")
|
logger.info(f"get_chat_instance:{dialogue}")
|
||||||
if not dialogue.chat_mode:
|
if not dialogue.chat_mode:
|
||||||
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
dialogue.chat_mode = ChatScene.ChatNormal.value()
|
||||||
@ -230,7 +266,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
|
|||||||
logger.info(f"chat_prepare:{dialogue}")
|
logger.info(f"chat_prepare:{dialogue}")
|
||||||
## check conv_uid
|
## check conv_uid
|
||||||
chat: BaseChat = get_chat_instance(dialogue)
|
chat: BaseChat = get_chat_instance(dialogue)
|
||||||
if len(chat.history_message) >0:
|
if len(chat.history_message) > 0:
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
resp = chat.prepare()
|
resp = chat.prepare()
|
||||||
return Result.succ(resp)
|
return Result.succ(resp)
|
||||||
@ -263,7 +299,6 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def no_stream_generator(chat):
|
async def no_stream_generator(chat):
|
||||||
msg = chat.nostream_call()
|
msg = chat.nostream_call()
|
||||||
msg = msg.replace("\n", "\\n")
|
msg = msg.replace("\n", "\\n")
|
||||||
|
@ -33,7 +33,6 @@ class ChatScene(Enum):
|
|||||||
code = "excel_learning",
|
code = "excel_learning",
|
||||||
name = "Excel Learning",
|
name = "Excel Learning",
|
||||||
describe = "Analyze and summarize your excel files.",
|
describe = "Analyze and summarize your excel files.",
|
||||||
param_types=["File Select"],
|
|
||||||
is_inner = True,
|
is_inner = True,
|
||||||
)
|
)
|
||||||
ChatExcel = Scene(
|
ChatExcel = Scene(
|
||||||
|
@ -92,7 +92,8 @@ class BaseChat(ABC):
|
|||||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
||||||
if select_param:
|
if select_param:
|
||||||
self.current_message.param_type = chat_mode.param_types()[0]
|
if len(chat_mode.param_types()) > 0:
|
||||||
|
self.current_message.param_type = chat_mode.param_types()[0]
|
||||||
self.current_message.param_value = select_param
|
self.current_message.param_value = select_param
|
||||||
self.current_tokens_used: int = 0
|
self.current_tokens_used: int = 0
|
||||||
|
|
||||||
|
@ -25,6 +25,8 @@ class ChatExcel(BaseChat):
|
|||||||
chat_retention_rounds = 2
|
chat_retention_rounds = 2
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||||
chat_mode = ChatScene.ChatExcel
|
chat_mode = ChatScene.ChatExcel
|
||||||
|
## TODO TEST
|
||||||
|
select_param = "/Users/tuyang.yhj/Downloads/example.xlsx"
|
||||||
|
|
||||||
self.excel_file_path = select_param
|
self.excel_file_path = select_param
|
||||||
self.excel_reader = ExcelReader(select_param)
|
self.excel_reader = ExcelReader(select_param)
|
||||||
@ -75,9 +77,11 @@ class ChatExcel(BaseChat):
|
|||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
logger.info(f"{self.chat_mode} prepare start!")
|
logger.info(f"{self.chat_mode} prepare start!")
|
||||||
|
if len(self.history_message) > 0:
|
||||||
|
return None
|
||||||
chat_param = {
|
chat_param = {
|
||||||
"chat_session_id": self.chat_session_id,
|
"chat_session_id": self.chat_session_id,
|
||||||
"user_input": self.excel_reader.excel_file_name + " analysis!",
|
"user_input": "[" + self.excel_reader.excel_file_name +"]" + self.excel_reader.extension + " analysis!",
|
||||||
"select_param": self.excel_file_path
|
"select_param": self.excel_file_path
|
||||||
}
|
}
|
||||||
learn_chat = ExcelLearning(**chat_param)
|
learn_chat = ExcelLearning(**chat_param)
|
||||||
|
@ -28,7 +28,10 @@ class ExcelReader:
|
|||||||
file_name = os.path.basename(file_path)
|
file_name = os.path.basename(file_path)
|
||||||
file_name_without_extension = os.path.splitext(file_name)[0]
|
file_name_without_extension = os.path.splitext(file_name)[0]
|
||||||
|
|
||||||
|
|
||||||
self.excel_file_name = file_name_without_extension
|
self.excel_file_name = file_name_without_extension
|
||||||
|
self.extension = os.path.splitext(file_name)[1]
|
||||||
|
|
||||||
self.table_name = file_name_without_extension
|
self.table_name = file_name_without_extension
|
||||||
# write data in duckdb
|
# write data in duckdb
|
||||||
self.db.register(self.table_name, self.df)
|
self.db.register(self.table_name, self.df)
|
||||||
@ -49,3 +52,4 @@ class ExcelReader:
|
|||||||
|
|
||||||
def get_sample_data(self):
|
def get_sample_data(self):
|
||||||
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;')
|
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;')
|
||||||
|
|
||||||
|
@ -45,13 +45,15 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("Could not import DBSummaryClient. ")
|
raise ValueError("Could not import DBSummaryClient. ")
|
||||||
client = DBSummaryClient()
|
client = DBSummaryClient()
|
||||||
try:
|
# try:
|
||||||
table_infos = client.get_db_summary(
|
# table_infos = client.get_db_summary(
|
||||||
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
# dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||||
)
|
# )
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
print("db summary find error!" + str(e))
|
# print("db summary find error!" + str(e))
|
||||||
table_infos = self.database.table_simple_info()
|
# table_infos = self.database.table_simple_info()
|
||||||
|
#
|
||||||
|
table_infos = self.database.table_simple_info()
|
||||||
|
|
||||||
input_values = {
|
input_values = {
|
||||||
"input": self.current_user_input,
|
"input": self.current_user_input,
|
||||||
|
@ -29,7 +29,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from pilot.server.knowledge.api import router as knowledge_router
|
from pilot.server.knowledge.api import router as knowledge_router
|
||||||
|
|
||||||
|
|
||||||
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
|
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||||
|
from pilot.openapi.base import validation_exception_handler
|
||||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
Loading…
Reference in New Issue
Block a user