mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-29 21:49:35 +00:00
feat(editor): editor api devlop
editor api devlop part 2
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -17,17 +18,18 @@ from pilot.openapi.api_view_model import (
|
||||
)
|
||||
from pilot.openapi.editor_view_model import (
|
||||
ChatDbRounds,
|
||||
ChartList,
|
||||
ChartDetail,
|
||||
ChatChartEditContext,
|
||||
ChatSqlEditContext,
|
||||
DbTable
|
||||
)
|
||||
|
||||
from pilot.openapi.api_v1.editor.sql_editor import DataNode,ChartRunData,SqlRunData
|
||||
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||
|
||||
from pilot.scene.chat_db.data_loader import DbDataLoader
|
||||
|
||||
router = APIRouter()
|
||||
CFG = Config()
|
||||
@@ -83,14 +85,24 @@ async def get_editor_sql(con_uid: str, round: int):
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "ai":
|
||||
return Result.succ(json.loads(element["data"]["content"]))
|
||||
return Result.faild("没有获取到可用的SQL返回结构")
|
||||
return Result.faild(msg="not have sql!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/sql/run", response_model=Result[List[dict]])
|
||||
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
||||
async def editor_sql_run(db_name: str, sql: str):
|
||||
logger.info("get_editor_sql_run:{},{}", db_name, sql)
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
return Result.succ(conn.run(sql))
|
||||
|
||||
start_time = time.time() * 1000
|
||||
colunms, sql_result = conn.query_ex(sql)
|
||||
# 计算执行耗时
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result
|
||||
)
|
||||
return Result.succ(sql_run_data)
|
||||
|
||||
|
||||
@router.post("/v1/sql/editor/submit", response_model=Result)
|
||||
@@ -99,34 +111,43 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
|
||||
|
||||
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
|
||||
if edit_round:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
element["data"]["content"]=""
|
||||
db_resp = json.loads(element["data"]["content"])
|
||||
db_resp['thoughts'] = sql_edit_context.new_speak
|
||||
db_resp['sql'] = sql_edit_context.new_sql
|
||||
element["data"]["content"] = json.dumps(db_resp)
|
||||
if element["type"] == "view":
|
||||
element["data"]["content"]=""
|
||||
data_loader = DbDataLoader()
|
||||
element["data"]["content"] = data_loader.get_table_view_by_conn(conn.run(sql_edit_context.new_sql),
|
||||
sql_edit_context.new_speak)
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.faild("Edit Faild!")
|
||||
return Result.faild(msg="Edit Faild!")
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart/list", response_model=Result[ChartDetail])
|
||||
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
|
||||
async def get_editor_chart_list(con_uid: str):
|
||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
db_name = last_round["param_value"]
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
return Result.succ(json.loads(element["data"]["content"]))
|
||||
|
||||
return Result.faild("没有获取到可用的SQL返回结构")
|
||||
chart_list: ChartList = ChartList(round=last_round, db_name=db_name,
|
||||
charts=json.loads(element["data"]["content"]))
|
||||
return Result.succ(chart_list)
|
||||
return Result.faild(msg="Not have charts!")
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart_info(con_uid: str, chart_uid: str):
|
||||
async def get_editor_chart_info(con_uid: str, chart_title: str):
|
||||
logger.info(f"get_editor_sql_rounds:{con_uid}")
|
||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
@@ -136,12 +157,12 @@ async def get_editor_chart_info(con_uid: str, chart_uid: str):
|
||||
db_name = last_round["param_value"]
|
||||
if not db_name:
|
||||
logger.error("this dashboard dialogue version too old, can't support editor!")
|
||||
return Result.faild("this dashboard dialogue version too old, can't support editor!")
|
||||
return Result.faild(msg="this dashboard dialogue version too old, can't support editor!")
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_uid, charts))[0]
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_title, charts))[0]
|
||||
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'],
|
||||
@@ -155,23 +176,27 @@ async def get_editor_chart_info(con_uid: str, chart_uid: str):
|
||||
)
|
||||
|
||||
return Result.succ(detail)
|
||||
return Result.faild("Can't Find Chart Detail Info!")
|
||||
return Result.faild(msg="Can't Find Chart Detail Info!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||
async def editor_chart_run(db_name: str, sql: str):
|
||||
logger.info(f"editor_chart_run:{db_name},{sql}")
|
||||
dashboard_data_loader:DashboardDataLoader = DashboardDataLoader()
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
|
||||
field_names,chart_values = dashboard_data_loader.get_chart_values_by_db(db_conn, sql)
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, sql)
|
||||
|
||||
sql_run_data:SqlRunData = SqlRunData(result_info="",
|
||||
run_cost="",
|
||||
colunms= field_names,
|
||||
values= db_conn.query_ex(sql)
|
||||
)
|
||||
return Result.succ(ChartRunData(sql_data=sql_run_data,chart_values=chart_values))
|
||||
start_time = time.time() * 1000
|
||||
colunms, sql_result = db_conn.query_ex(sql)
|
||||
# 计算执行耗时
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result
|
||||
)
|
||||
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values))
|
||||
|
||||
|
||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||
@@ -180,21 +205,40 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
|
||||
history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
edit_round = list(filter(lambda x: x['chat_order'] == chart_edit_context.conv_round, history_messages))[0]
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
|
||||
|
||||
edit_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
if edit_round:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
|
||||
try:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_title, charts))[0]
|
||||
if chart_edit_context.new_chart_type:
|
||||
find_chart['chart_type'] = chart_edit_context.new_chart_type
|
||||
if chart_edit_context.new_comment:
|
||||
find_chart['chart_desc'] = chart_edit_context.new_comment
|
||||
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn,
|
||||
chart_edit_context.new_sql)
|
||||
find_chart['chart_sql'] = chart_edit_context.new_sql
|
||||
find_chart['values'] = [value.dict() for value in chart_values]
|
||||
find_chart['column_name'] = field_names
|
||||
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
|
||||
|
||||
element["data"]["content"] = json.dumps(view_data, ensure_ascii=False)
|
||||
if element["type"] == "ai":
|
||||
ai_resp: dict = json.loads(element["data"]["content"])
|
||||
edit_item = list(filter(lambda x: x['title'] == chart_edit_context.chart_title, ai_resp))[0]
|
||||
|
||||
edit_item["sql"] = chart_edit_context.new_sql
|
||||
edit_item["showcase"] = chart_edit_context.new_chart_type
|
||||
edit_item["thoughts"] = chart_edit_context.new_comment
|
||||
element["data"]["content"] = json.dumps(ai_resp, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"edit chart exception!{str(e)}" ,e)
|
||||
return Result.faild(msg=f"Edit chart exception!{str(e)}")
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.faild("Edit Faild!")
|
||||
return Result.faild(msg="Edit Faild!")
|
||||
|
||||
@@ -21,6 +21,12 @@ class ChatDbRounds(BaseModel):
|
||||
round_name: str
|
||||
|
||||
|
||||
class ChartList(BaseModel):
|
||||
round: int
|
||||
db_name: str
|
||||
charts: List
|
||||
|
||||
|
||||
class ChartDetail(BaseModel):
|
||||
chart_uid: str
|
||||
chart_type: str
|
||||
@@ -34,24 +40,25 @@ class ChartDetail(BaseModel):
|
||||
|
||||
class ChatChartEditContext(BaseModel):
|
||||
conv_uid: str
|
||||
conv_round: int
|
||||
chart_uid: str
|
||||
|
||||
chart_title: str
|
||||
db_name: str
|
||||
old_sql: str
|
||||
new_sql: str
|
||||
comment: str
|
||||
gmt_create: int
|
||||
|
||||
new_view_info: str
|
||||
new_chart_type: str
|
||||
new_sql: str
|
||||
new_comment: str
|
||||
gmt_create: int
|
||||
|
||||
|
||||
class ChatSqlEditContext(BaseModel):
|
||||
conv_uid: str
|
||||
db_name: str
|
||||
conv_round: int
|
||||
|
||||
old_sql: str
|
||||
new_sql: str
|
||||
comment: str
|
||||
old_speak: str
|
||||
gmt_create: int
|
||||
|
||||
new_sql: str
|
||||
new_speak: str
|
||||
new_view_info: str
|
||||
|
||||
Reference in New Issue
Block a user