mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 14:40:56 +00:00
feat(core): APP use new SDK component (#1050)
This commit is contained in:
@@ -1,37 +1,30 @@
|
||||
import json
|
||||
import time
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
)
|
||||
|
||||
from typing import List
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
from dbgpt.app.scene import ChatFactory
|
||||
|
||||
from dbgpt.app.openapi.api_view_model import (
|
||||
Result,
|
||||
from dbgpt.app.openapi.api_v1.editor.service import EditorService
|
||||
from dbgpt.app.openapi.api_v1.editor.sql_editor import (
|
||||
ChartRunData,
|
||||
DataNode,
|
||||
SqlRunData,
|
||||
)
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.app.openapi.editor_view_model import (
|
||||
ChatDbRounds,
|
||||
ChartList,
|
||||
ChartDetail,
|
||||
ChartList,
|
||||
ChatChartEditContext,
|
||||
ChatDbRounds,
|
||||
ChatSqlEditContext,
|
||||
DbTable,
|
||||
)
|
||||
|
||||
from dbgpt.app.openapi.api_v1.editor.sql_editor import (
|
||||
DataNode,
|
||||
ChartRunData,
|
||||
SqlRunData,
|
||||
)
|
||||
from dbgpt.core.interface.message import OnceConversation
|
||||
from dbgpt.app.scene import ChatFactory
|
||||
from dbgpt.app.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||
from dbgpt.app.scene.chat_db.data_loader import DbDataLoader
|
||||
from dbgpt.core.interface.message import OnceConversation
|
||||
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
|
||||
|
||||
router = APIRouter()
|
||||
@@ -41,6 +34,14 @@ CHAT_FACTORY = ChatFactory()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_conversation_serve() -> ConversationServe:
|
||||
return ConversationServe.get_instance(CFG.SYSTEM_APP)
|
||||
|
||||
|
||||
def get_edit_service() -> EditorService:
|
||||
return EditorService.get_instance(CFG.SYSTEM_APP)
|
||||
|
||||
|
||||
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
||||
async def get_editor_tables(
|
||||
db_name: str, page_index: int, page_size: int, search_str: str = ""
|
||||
@@ -69,48 +70,21 @@ async def get_editor_tables(
|
||||
|
||||
|
||||
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
|
||||
async def get_editor_sql_rounds(con_uid: str):
|
||||
async def get_editor_sql_rounds(
|
||||
con_uid: str, editor_service: EditorService = Depends(get_edit_service)
|
||||
):
|
||||
logger.info("get_editor_sql_rounds:{con_uid}")
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
result: List = []
|
||||
for once in history_messages:
|
||||
round_name: str = ""
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "human":
|
||||
round_name = element["data"]["content"]
|
||||
if once.get("param_value"):
|
||||
round: ChatDbRounds = ChatDbRounds(
|
||||
round=once["chat_order"],
|
||||
db_name=once["param_value"],
|
||||
round_name=round_name,
|
||||
)
|
||||
result.append(round)
|
||||
return Result.succ(result)
|
||||
return Result.succ(editor_service.get_editor_sql_rounds(con_uid))
|
||||
|
||||
|
||||
@router.get("/v1/editor/sql", response_model=Result[dict])
|
||||
async def get_editor_sql(con_uid: str, round: int):
|
||||
async def get_editor_sql(
|
||||
con_uid: str, round: int, editor_service: EditorService = Depends(get_edit_service)
|
||||
):
|
||||
logger.info(f"get_editor_sql:{con_uid},{round}")
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
for once in history_messages:
|
||||
if int(once["chat_order"]) == round:
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "ai":
|
||||
logger.info(
|
||||
f'history ai json resp:{element["data"]["content"]}'
|
||||
)
|
||||
context = (
|
||||
element["data"]["content"]
|
||||
.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
)
|
||||
return Result.succ(json.loads(context))
|
||||
context = editor_service.get_editor_sql_by_round(con_uid, round)
|
||||
if context:
|
||||
return Result.succ(context)
|
||||
return Result.failed(msg="not have sql!")
|
||||
|
||||
|
||||
@@ -120,7 +94,7 @@ async def editor_sql_run(run_param: dict = Body()):
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
if not db_name and not sql:
|
||||
return Result.failed("SQL run param error!")
|
||||
return Result.failed(msg="SQL run param error!")
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
|
||||
try:
|
||||
@@ -145,102 +119,43 @@ async def editor_sql_run(run_param: dict = Body()):
|
||||
|
||||
|
||||
@router.post("/v1/sql/editor/submit")
|
||||
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||
async def sql_editor_submit(
|
||||
sql_edit_context: ChatSqlEditContext = Body(),
|
||||
editor_service: EditorService = Depends(get_edit_service),
|
||||
):
|
||||
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
|
||||
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(sql_edit_context.conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
|
||||
|
||||
edit_round = list(
|
||||
filter(
|
||||
lambda x: x["chat_order"] == sql_edit_context.conv_round,
|
||||
history_messages,
|
||||
)
|
||||
)[0]
|
||||
if edit_round:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
db_resp = json.loads(element["data"]["content"])
|
||||
db_resp["thoughts"] = sql_edit_context.new_speak
|
||||
db_resp["sql"] = sql_edit_context.new_sql
|
||||
element["data"]["content"] = json.dumps(db_resp)
|
||||
if element["type"] == "view":
|
||||
data_loader = DbDataLoader()
|
||||
element["data"]["content"] = data_loader.get_table_view_by_conn(
|
||||
conn.run_to_df(sql_edit_context.new_sql),
|
||||
sql_edit_context.new_speak,
|
||||
)
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.failed(msg="Edit Failed!")
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
|
||||
try:
|
||||
editor_service.sql_editor_submit_and_save(sql_edit_context, conn)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error(f"edit sql exception!{str(e)}")
|
||||
return Result.failed(msg=f"Edit sql exception!{str(e)}")
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
|
||||
async def get_editor_chart_list(con_uid: str):
|
||||
async def get_editor_chart_list(
|
||||
con_uid: str,
|
||||
editor_service: EditorService = Depends(get_edit_service),
|
||||
):
|
||||
logger.info(
|
||||
f"get_editor_sql_rounds:{con_uid}",
|
||||
)
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
db_name = last_round["param_value"]
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
chart_list: ChartList = ChartList(
|
||||
round=last_round["chat_order"],
|
||||
db_name=db_name,
|
||||
charts=json.loads(element["data"]["content"]),
|
||||
)
|
||||
return Result.succ(chart_list)
|
||||
chart_list = editor_service.get_editor_chart_list(con_uid)
|
||||
if chart_list:
|
||||
return Result.succ(chart_list)
|
||||
return Result.failed(msg="Not have charts!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart_info(param: dict = Body()):
|
||||
async def get_editor_chart_info(
|
||||
param: dict = Body(), editor_service: EditorService = Depends(get_edit_service)
|
||||
):
|
||||
logger.info(f"get_editor_chart_info:{param}")
|
||||
conv_uid = param["con_uid"]
|
||||
chart_title = param["chart_title"]
|
||||
|
||||
chat_history_fac = ChatHistory()
|
||||
history_mem = chat_history_fac.get_store_instance(conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
db_name = last_round["param_value"]
|
||||
if not db_name:
|
||||
logger.error(
|
||||
"this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
return Result.failed(
|
||||
msg="this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"])
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(
|
||||
filter(lambda x: x["chart_name"] == chart_title, charts)
|
||||
)[0]
|
||||
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
detail: ChartDetail = ChartDetail(
|
||||
chart_uid=find_chart["chart_uid"],
|
||||
chart_type=find_chart["chart_type"],
|
||||
chart_desc=find_chart["chart_desc"],
|
||||
chart_sql=find_chart["chart_sql"],
|
||||
db_name=db_name,
|
||||
chart_name=find_chart["chart_name"],
|
||||
chart_value=find_chart["values"],
|
||||
table_value=conn.run(find_chart["chart_sql"]),
|
||||
)
|
||||
|
||||
return Result.succ(detail)
|
||||
return Result.failed(msg="Can't Find Chart Detail Info!")
|
||||
return editor_service.get_editor_chart_info(conv_uid, chart_title, CFG)
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||
|
190
dbgpt/app/openapi/api_v1/editor/service.py
Normal file
190
dbgpt/app/openapi/api_v1/editor/service.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.app.openapi.editor_view_model import (
|
||||
ChartDetail,
|
||||
ChartList,
|
||||
ChatDbRounds,
|
||||
ChatSqlEditContext,
|
||||
)
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core import BaseOutputParser
|
||||
from dbgpt.core.interface.message import (
|
||||
MessageStorageItem,
|
||||
StorageConversation,
|
||||
_split_messages_by_round,
|
||||
)
|
||||
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EditorService(BaseComponent):
|
||||
name = "dbgpt_app_editor_service"
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
self._system_app: SystemApp = system_app
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
self._system_app = system_app
|
||||
|
||||
def conv_serve(self) -> ConversationServe:
|
||||
return ConversationServe.get_instance(self._system_app)
|
||||
|
||||
def get_storage_conv(self, conv_uid: str) -> StorageConversation:
|
||||
conv_serve: ConversationServe = self.conv_serve()
|
||||
return StorageConversation(
|
||||
conv_uid,
|
||||
conv_storage=conv_serve.conv_storage,
|
||||
message_storage=conv_serve.message_storage,
|
||||
)
|
||||
|
||||
def get_editor_sql_rounds(self, conv_uid: str) -> List[ChatDbRounds]:
|
||||
storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
|
||||
messages_by_round = _split_messages_by_round(storage_conv.messages)
|
||||
result: List[ChatDbRounds] = []
|
||||
for one_round_message in messages_by_round:
|
||||
if not one_round_message:
|
||||
continue
|
||||
for message in one_round_message:
|
||||
if message.type == "human":
|
||||
round_name = message.content
|
||||
if message.additional_kwargs.get("param_value"):
|
||||
chat_db_round: ChatDbRounds = ChatDbRounds(
|
||||
round=message.round_index,
|
||||
db_name=message.additional_kwargs.get("param_value"),
|
||||
round_name=round_name,
|
||||
)
|
||||
result.append(chat_db_round)
|
||||
|
||||
return result
|
||||
|
||||
def get_editor_sql_by_round(
|
||||
self, conv_uid: str, round_index: int
|
||||
) -> Optional[Dict]:
|
||||
storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
|
||||
messages_by_round = _split_messages_by_round(storage_conv.messages)
|
||||
for one_round_message in messages_by_round:
|
||||
if not one_round_message:
|
||||
continue
|
||||
for message in one_round_message:
|
||||
if message.type == "ai" and message.round_index == round_index:
|
||||
content = message.content
|
||||
logger.info(f"history ai json resp: {content}")
|
||||
# context = content.replace("\\n", " ").replace("\n", " ")
|
||||
context_dict = _parse_pure_dict(content)
|
||||
return context_dict
|
||||
return None
|
||||
|
||||
def sql_editor_submit_and_save(
|
||||
self, sql_edit_context: ChatSqlEditContext, connection: BaseConnect
|
||||
):
|
||||
storage_conv: StorageConversation = self.get_storage_conv(
|
||||
sql_edit_context.conv_uid
|
||||
)
|
||||
if not storage_conv.save_message_independent:
|
||||
raise ValueError(
|
||||
"Submit sql and save just support independent conversation mode(after v0.4.6)"
|
||||
)
|
||||
conv_serve: ConversationServe = self.conv_serve()
|
||||
messages_by_round = _split_messages_by_round(storage_conv.messages)
|
||||
to_update_messages = []
|
||||
for one_round_message in messages_by_round:
|
||||
if not one_round_message:
|
||||
continue
|
||||
if one_round_message[0].round_index == sql_edit_context.conv_round:
|
||||
for message in one_round_message:
|
||||
if message.type == "ai":
|
||||
db_resp = _parse_pure_dict(message.content)
|
||||
db_resp["thoughts"] = sql_edit_context.new_speak
|
||||
db_resp["sql"] = sql_edit_context.new_sql
|
||||
message.content = json.dumps(db_resp, ensure_ascii=False)
|
||||
to_update_messages.append(
|
||||
MessageStorageItem(
|
||||
storage_conv.conv_uid, message.index, message.to_dict()
|
||||
)
|
||||
)
|
||||
# TODO not support update view message now
|
||||
# if message.type == "view":
|
||||
# data_loader = DbDataLoader()
|
||||
# message.content = data_loader.get_table_view_by_conn(
|
||||
# connection.run_to_df(sql_edit_context.new_sql),
|
||||
# sql_edit_context.new_speak,
|
||||
# )
|
||||
# to_update_messages.append(
|
||||
# MessageStorageItem(
|
||||
# storage_conv.conv_uid, message.index, message.to_dict()
|
||||
# )
|
||||
# )
|
||||
if to_update_messages:
|
||||
conv_serve.message_storage.save_or_update_list(to_update_messages)
|
||||
return
|
||||
|
||||
def get_editor_chart_list(self, conv_uid: str) -> Optional[ChartList]:
|
||||
storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
|
||||
messages_by_round = _split_messages_by_round(storage_conv.messages)
|
||||
for one_round_message in messages_by_round:
|
||||
if not one_round_message:
|
||||
continue
|
||||
for message in one_round_message:
|
||||
if message.type == "ai":
|
||||
context_dict = _parse_pure_dict(message.content)
|
||||
chart_list: ChartList = ChartList(
|
||||
round=message.round_index,
|
||||
db_name=message.additional_kwargs.get("param_value"),
|
||||
charts=context_dict,
|
||||
)
|
||||
return chart_list
|
||||
|
||||
def get_editor_chart_info(
|
||||
self, conv_uid: str, chart_title: str, cfg: Config
|
||||
) -> Result[ChartDetail]:
|
||||
storage_conv: StorageConversation = self.get_storage_conv(conv_uid)
|
||||
messages_by_round = _split_messages_by_round(storage_conv.messages)
|
||||
for one_round_message in messages_by_round:
|
||||
if not one_round_message:
|
||||
continue
|
||||
for message in one_round_message:
|
||||
db_name = message.additional_kwargs.get("param_value")
|
||||
if not db_name:
|
||||
logger.error(
|
||||
"this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
return Result.failed(
|
||||
msg="this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
if message.type == "view":
|
||||
view_data: dict = _parse_pure_dict(message.content)
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(
|
||||
filter(lambda x: x["chart_name"] == chart_title, charts)
|
||||
)[0]
|
||||
|
||||
conn = cfg.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
detail: ChartDetail = ChartDetail(
|
||||
chart_uid=find_chart["chart_uid"],
|
||||
chart_type=find_chart["chart_type"],
|
||||
chart_desc=find_chart["chart_desc"],
|
||||
chart_sql=find_chart["chart_sql"],
|
||||
db_name=db_name,
|
||||
chart_name=find_chart["chart_name"],
|
||||
chart_value=find_chart["values"],
|
||||
table_value=conn.run(find_chart["chart_sql"]),
|
||||
)
|
||||
return Result.succ(detail)
|
||||
return Result.failed(msg="Can't Find Chart Detail Info!")
|
||||
|
||||
|
||||
def _parse_pure_dict(res_str: str) -> Dict:
|
||||
output_parser = BaseOutputParser()
|
||||
context = output_parser.parse_prompt_response(res_str)
|
||||
return json.loads(context)
|
@@ -1,4 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem
|
||||
|
||||
|
Reference in New Issue
Block a user