mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 12:42:34 +00:00
feat(editor): editor api devlop
editor api devlop part 1
This commit is contained in:
parent
de3bcd1f68
commit
a95d292bf9
@ -250,17 +250,17 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
"""Format the error message"""
|
"""Format the error message"""
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
def __write(self, session, write_sql):
|
def __write(self, write_sql):
|
||||||
print(f"Write[{write_sql}]")
|
print(f"Write[{write_sql}]")
|
||||||
db_cache = self._engine.url.database
|
db_cache = self._engine.url.database
|
||||||
result = session.execute(text(write_sql))
|
result = self.session.execute(text(write_sql))
|
||||||
session.commit()
|
self.session.commit()
|
||||||
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||||
session.execute(text(f"use `{db_cache}`"))
|
self.session.execute(text(f"use `{db_cache}`"))
|
||||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||||
return result.rowcount
|
return result.rowcount
|
||||||
|
|
||||||
def __query(self, session, query, fetch: str = "all"):
|
def __query(self,query, fetch: str = "all"):
|
||||||
"""
|
"""
|
||||||
only for query
|
only for query
|
||||||
Args:
|
Args:
|
||||||
@ -274,7 +274,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
print(f"Query[{query}]")
|
print(f"Query[{query}]")
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
cursor = session.execute(text(query))
|
cursor = self.session.execute(text(query))
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
if fetch == "all":
|
if fetch == "all":
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
@ -288,7 +288,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
result.insert(0, field_names)
|
result.insert(0, field_names)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def query_ex(self, session, query, fetch: str = "all"):
|
def query_ex(self, query, fetch: str = "all"):
|
||||||
"""
|
"""
|
||||||
only for query
|
only for query
|
||||||
Args:
|
Args:
|
||||||
@ -300,7 +300,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
print(f"Query[{query}]")
|
print(f"Query[{query}]")
|
||||||
if not query:
|
if not query:
|
||||||
return []
|
return []
|
||||||
cursor = session.execute(text(query))
|
cursor = self.session.execute(text(query))
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
if fetch == "all":
|
if fetch == "all":
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
@ -313,7 +313,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
result = list(result)
|
result = list(result)
|
||||||
return field_names, result
|
return field_names, result
|
||||||
|
|
||||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
def run(self, command: str, fetch: str = "all") -> List:
|
||||||
"""Execute a SQL command and return a string representing the results."""
|
"""Execute a SQL command and return a string representing the results."""
|
||||||
print("SQL:" + command)
|
print("SQL:" + command)
|
||||||
if not command:
|
if not command:
|
||||||
@ -321,17 +321,17 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
|
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||||
if ttype == sqlparse.tokens.DML:
|
if ttype == sqlparse.tokens.DML:
|
||||||
if sql_type == "SELECT":
|
if sql_type == "SELECT":
|
||||||
return self.__query(session, command, fetch)
|
return self.__query( command, fetch)
|
||||||
else:
|
else:
|
||||||
self.__write(session, command)
|
self.__write( command)
|
||||||
select_sql = self.convert_sql_write_to_select(command)
|
select_sql = self.convert_sql_write_to_select(command)
|
||||||
print(f"write result query:{select_sql}")
|
print(f"write result query:{select_sql}")
|
||||||
return self.__query(session, select_sql)
|
return self.__query( select_sql)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"DDL execution determines whether to enable through configuration ")
|
print(f"DDL execution determines whether to enable through configuration ")
|
||||||
cursor = session.execute(text(command))
|
cursor = self.session.execute(text(command))
|
||||||
session.commit()
|
self.session.commit()
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
field_names = tuple(i[0:] for i in cursor.keys())
|
field_names = tuple(i[0:] for i in cursor.keys())
|
||||||
@ -339,10 +339,10 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
result.insert(0, field_names)
|
result.insert(0, field_names)
|
||||||
print("DDL Result:" + str(result))
|
print("DDL Result:" + str(result))
|
||||||
if not result:
|
if not result:
|
||||||
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
|
return self.__query( f"SHOW COLUMNS FROM {table_name}")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
|
return self.__query( f"SHOW COLUMNS FROM {table_name}")
|
||||||
|
|
||||||
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
||||||
"""Execute a SQL command and return a string representing the results.
|
"""Execute a SQL command and return a string representing the results.
|
||||||
|
@ -94,6 +94,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
cursor.commit()
|
cursor.commit()
|
||||||
self.connect.commit()
|
self.connect.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def update(self, messages:List[OnceConversation]) -> None:
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE chat_history set messages=? where conv_uid=?",
|
||||||
|
[json.dumps(messages, ensure_ascii=False), self.chat_seesion_id],
|
||||||
|
)
|
||||||
|
cursor.commit()
|
||||||
|
self.connect.commit()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@ -134,6 +144,24 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
|||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def conv_info(self, conv_uid: str = None) -> None:
|
||||||
|
cursor = self.connect.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT * FROM chat_history where conv_uid=? ",
|
||||||
|
[conv_uid],
|
||||||
|
)
|
||||||
|
# 获取查询结果字段名
|
||||||
|
fields = [field[0] for field in cursor.description]
|
||||||
|
|
||||||
|
for row in cursor.fetchone():
|
||||||
|
row_dict = {}
|
||||||
|
for i, field in enumerate(fields):
|
||||||
|
row_dict[field] = row[i]
|
||||||
|
return row_dict
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def get_messages(self) -> List[OnceConversation]:
|
def get_messages(self) -> List[OnceConversation]:
|
||||||
cursor = self.connect.cursor()
|
cursor = self.connect.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
|
@ -1,17 +1,12 @@
|
|||||||
import os
|
import json
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Request,
|
|
||||||
Body,
|
Body,
|
||||||
BackgroundTasks,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.server.knowledge.service import KnowledgeService
|
|
||||||
|
|
||||||
from pilot.scene.chat_factory import ChatFactory
|
from pilot.scene.chat_factory import ChatFactory
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
@ -19,9 +14,6 @@ from pilot.utils import build_logger
|
|||||||
|
|
||||||
from pilot.openapi.api_view_model import (
|
from pilot.openapi.api_view_model import (
|
||||||
Result,
|
Result,
|
||||||
ConversationVo,
|
|
||||||
MessageVo,
|
|
||||||
ChatSceneVo,
|
|
||||||
)
|
)
|
||||||
from pilot.openapi.editor_view_model import (
|
from pilot.openapi.editor_view_model import (
|
||||||
ChatDbRounds,
|
ChatDbRounds,
|
||||||
@ -31,9 +23,11 @@ from pilot.openapi.editor_view_model import (
|
|||||||
DbTable
|
DbTable
|
||||||
)
|
)
|
||||||
|
|
||||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData
|
from pilot.openapi.api_v1.editor.sql_editor import DataNode,ChartRunData,SqlRunData
|
||||||
|
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||||
|
from pilot.scene.message import OnceConversation
|
||||||
|
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||||
|
|
||||||
from pilot.scene.chat_db.auto_execute.out_parser import SqlAction
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -43,46 +37,164 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
|
|||||||
|
|
||||||
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
||||||
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
|
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
|
||||||
return Result.succ(None)
|
logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str)
|
||||||
|
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
tables = db_conn.get_table_names()
|
||||||
|
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
|
||||||
|
for table in tables:
|
||||||
|
table_node: DataNode = DataNode(title=table, key=table, type="table")
|
||||||
|
db_node.children.append(table_node)
|
||||||
|
fields = db_conn.get_fields("transaction_order")
|
||||||
|
for field in fields:
|
||||||
|
table_node.children.append(
|
||||||
|
DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3],
|
||||||
|
comment=field[-1]))
|
||||||
|
|
||||||
|
return Result.succ(db_node)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
|
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
|
||||||
async def get_editor_sql_rounds(con_uid: str):
|
async def get_editor_sql_rounds(con_uid: str):
|
||||||
return Result.succ(None)
|
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||||
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
|
if history_messages:
|
||||||
|
result: List = []
|
||||||
|
for once in history_messages:
|
||||||
|
round_name: str = ""
|
||||||
|
for element in once["messages"]:
|
||||||
|
if element["type"] == "human":
|
||||||
|
round_name = element["data"]["content"]
|
||||||
|
if once.get("param_value"):
|
||||||
|
round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"],
|
||||||
|
round_name=round_name)
|
||||||
|
result.append(round)
|
||||||
|
return Result.succ(result)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/editor/sql", response_model=Result[SqlAction])
|
@router.get("/v1/editor/sql", response_model=Result[dict])
|
||||||
async def get_editor_sql(con_uid: str, round: int):
|
async def get_editor_sql(con_uid: str, round: int):
|
||||||
return Result.succ(None)
|
logger.info("get_editor_sql:{},{}", con_uid, round)
|
||||||
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
@router.get("/v1/editor/chart/details", response_model=Result[ChartDetail])
|
if history_messages:
|
||||||
async def get_editor_sql_rounds(con_uid: str):
|
for once in history_messages:
|
||||||
return Result.succ(None)
|
if int(once["chat_order"]) == round:
|
||||||
|
for element in once["messages"]:
|
||||||
|
if element["type"] == "ai":
|
||||||
@router.get("/v1/editor/chart", response_model=Result[ChartDetail])
|
return Result.succ(json.loads(element["data"]["content"]))
|
||||||
async def get_editor_chart(con_uid: str, chart_uid: str):
|
return Result.faild("没有获取到可用的SQL返回结构")
|
||||||
return Result.succ(None)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/editor/sql/run", response_model=Result[List[dict]])
|
@router.post("/v1/editor/sql/run", response_model=Result[List[dict]])
|
||||||
async def get_editor_chart(db_name: str, sql: str):
|
async def editor_sql_run(db_name: str, sql: str):
|
||||||
return Result.succ(None)
|
logger.info("get_editor_sql_run:{},{}", db_name, sql)
|
||||||
|
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
return Result.succ(conn.run(sql))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartData])
|
@router.post("/v1/sql/editor/submit", response_model=Result)
|
||||||
async def get_editor_chart(db_name: str, sql: str):
|
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||||
|
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
|
||||||
|
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
|
if history_messages:
|
||||||
|
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
|
||||||
|
if edit_round:
|
||||||
|
for element in edit_round["messages"]:
|
||||||
|
if element["type"] == "ai":
|
||||||
|
element["data"]["content"]=""
|
||||||
|
if element["type"] == "view":
|
||||||
|
element["data"]["content"]=""
|
||||||
|
history_mem.update(history_messages)
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
return Result.faild("Edit Faild!")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/editor/chart/list", response_model=Result[ChartDetail])
|
||||||
|
async def get_editor_chart_list(con_uid: str):
|
||||||
|
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||||
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
|
if history_messages:
|
||||||
|
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||||
|
for element in last_round["messages"]:
|
||||||
|
if element["type"] == "ai":
|
||||||
|
return Result.succ(json.loads(element["data"]["content"]))
|
||||||
|
|
||||||
|
return Result.faild("没有获取到可用的SQL返回结构")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||||
|
async def get_editor_chart_info(con_uid: str, chart_uid: str):
|
||||||
|
logger.info(f"get_editor_sql_rounds:{con_uid}")
|
||||||
|
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||||
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
|
if history_messages:
|
||||||
|
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||||
|
db_name = last_round["param_value"]
|
||||||
|
if not db_name:
|
||||||
|
logger.error("this dashboard dialogue version too old, can't support editor!")
|
||||||
|
return Result.faild("this dashboard dialogue version too old, can't support editor!")
|
||||||
|
for element in last_round["messages"]:
|
||||||
|
if element["type"] == "view":
|
||||||
|
view_data: dict = json.loads(element["data"]["content"]);
|
||||||
|
charts: List = view_data.get("charts")
|
||||||
|
find_chart = list(filter(lambda x: x['chart_name'] == chart_uid, charts))[0]
|
||||||
|
|
||||||
|
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'],
|
||||||
|
chart_type=find_chart['chart_type'],
|
||||||
|
chart_desc=find_chart['chart_desc'],
|
||||||
|
chart_sql=find_chart['chart_sql'],
|
||||||
|
db_name=db_name,
|
||||||
|
chart_name=find_chart['chart_name'],
|
||||||
|
chart_value=find_chart['values'],
|
||||||
|
table_value=conn.run(find_chart['chart_sql'])
|
||||||
|
)
|
||||||
|
|
||||||
|
return Result.succ(detail)
|
||||||
|
return Result.faild("Can't Find Chart Detail Info!")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||||
|
async def editor_chart_run(db_name: str, sql: str):
|
||||||
|
logger.info(f"editor_chart_run:{db_name},{sql}")
|
||||||
|
dashboard_data_loader:DashboardDataLoader = DashboardDataLoader()
|
||||||
|
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
|
||||||
|
field_names,chart_values = dashboard_data_loader.get_chart_values_by_db(db_conn, sql)
|
||||||
|
|
||||||
|
sql_run_data:SqlRunData = SqlRunData(result_info="",
|
||||||
|
run_cost="",
|
||||||
|
colunms= field_names,
|
||||||
|
values= db_conn.query_ex(sql)
|
||||||
|
)
|
||||||
|
return Result.succ(ChartRunData(sql_data=sql_run_data,chart_values=chart_values))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||||
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
||||||
return Result.succ(None)
|
logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
|
||||||
|
history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)
|
||||||
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
|
if history_messages:
|
||||||
|
edit_round = list(filter(lambda x: x['chat_order'] == chart_edit_context.conv_round, history_messages))[0]
|
||||||
|
if edit_round:
|
||||||
|
for element in edit_round["messages"]:
|
||||||
|
if element["type"] == "ai":
|
||||||
|
view_data: dict = json.loads(element["data"]["content"]);
|
||||||
|
charts: List = view_data.get("charts")
|
||||||
|
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/sql/editor/submit", response_model=Result[bool])
|
if element["type"] == "view":
|
||||||
async def chart_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
view_data: dict = json.loads(element["data"]["content"]);
|
||||||
|
charts: List = view_data.get("charts")
|
||||||
|
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
|
||||||
|
|
||||||
|
|
||||||
|
history_mem.update(history_messages)
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
return Result.faild("Edit Faild!")
|
||||||
|
26
pilot/openapi/api_v1/editor/sql_editor.py
Normal file
26
pilot/openapi/api_v1/editor/sql_editor.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from typing import List
|
||||||
|
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||||
|
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
|
||||||
|
|
||||||
|
|
||||||
|
class DataNode(BaseModel):
|
||||||
|
title: str
|
||||||
|
key: str
|
||||||
|
|
||||||
|
type: str = ""
|
||||||
|
default_value: str = None
|
||||||
|
can_null: str = 'YES'
|
||||||
|
comment: str = None
|
||||||
|
children: List = []
|
||||||
|
|
||||||
|
|
||||||
|
class SqlRunData(BaseModel):
|
||||||
|
result_info: str
|
||||||
|
run_cost: str
|
||||||
|
colunms: List[str]
|
||||||
|
values: List
|
||||||
|
|
||||||
|
|
||||||
|
class ChartRunData(BaseModel):
|
||||||
|
sql_data: SqlRunData
|
||||||
|
chart_values: List[ValueItem]
|
@ -24,10 +24,12 @@ class ChatDbRounds(BaseModel):
|
|||||||
class ChartDetail(BaseModel):
|
class ChartDetail(BaseModel):
|
||||||
chart_uid: str
|
chart_uid: str
|
||||||
chart_type: str
|
chart_type: str
|
||||||
|
chart_desc: str
|
||||||
|
chart_sql: str
|
||||||
db_name: str
|
db_name: str
|
||||||
chart_name: str
|
chart_name: str
|
||||||
chart_value: str
|
chart_value: Any
|
||||||
chat_round: int # defualt last round
|
table_value: Any
|
||||||
|
|
||||||
|
|
||||||
class ChatChartEditContext(BaseModel):
|
class ChatChartEditContext(BaseModel):
|
||||||
|
@ -64,6 +64,7 @@ class BaseChat(ABC):
|
|||||||
chat_mode,
|
chat_mode,
|
||||||
chat_session_id,
|
chat_session_id,
|
||||||
current_user_input,
|
current_user_input,
|
||||||
|
select_param: Any = None
|
||||||
):
|
):
|
||||||
self.chat_session_id = chat_session_id
|
self.chat_session_id = chat_session_id
|
||||||
self.chat_mode = chat_mode
|
self.chat_mode = chat_mode
|
||||||
@ -87,6 +88,9 @@ class BaseChat(ABC):
|
|||||||
)
|
)
|
||||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
||||||
|
if select_param:
|
||||||
|
self.current_message.param_type = chat_mode.param_types()[0]
|
||||||
|
self.current_message.param_value = select_param
|
||||||
self.current_tokens_used: int = 0
|
self.current_tokens_used: int = 0
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -111,6 +115,24 @@ class BaseChat(ABC):
|
|||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
return prompt_response
|
return prompt_response
|
||||||
|
|
||||||
|
def get_llm_speak(self, prompt_define_response):
|
||||||
|
if hasattr(prompt_define_response, "thoughts"):
|
||||||
|
if isinstance(prompt_define_response.thoughts, dict):
|
||||||
|
if "speak" in prompt_define_response.thoughts:
|
||||||
|
speak_to_user = prompt_define_response.thoughts.get("speak")
|
||||||
|
else:
|
||||||
|
speak_to_user = str(prompt_define_response.thoughts)
|
||||||
|
else:
|
||||||
|
if hasattr(prompt_define_response.thoughts, "speak"):
|
||||||
|
speak_to_user = prompt_define_response.thoughts.get("speak")
|
||||||
|
elif hasattr(prompt_define_response.thoughts, "reasoning"):
|
||||||
|
speak_to_user = prompt_define_response.thoughts.get("reasoning")
|
||||||
|
else:
|
||||||
|
speak_to_user = prompt_define_response.thoughts
|
||||||
|
else:
|
||||||
|
speak_to_user = prompt_define_response
|
||||||
|
return speak_to_user
|
||||||
|
|
||||||
def __call_base(self):
|
def __call_base(self):
|
||||||
input_values = self.generate_input_values()
|
input_values = self.generate_input_values()
|
||||||
### Chat sequence advance
|
### Chat sequence advance
|
||||||
@ -209,26 +231,13 @@ class BaseChat(ABC):
|
|||||||
ai_response_text
|
ai_response_text
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
### sql run
|
||||||
result = self.do_action(prompt_define_response)
|
result = self.do_action(prompt_define_response)
|
||||||
|
|
||||||
if hasattr(prompt_define_response, "thoughts"):
|
### llm speaker
|
||||||
if isinstance(prompt_define_response.thoughts, dict):
|
speak_to_user = self.get_llm_speak(prompt_define_response)
|
||||||
if "speak" in prompt_define_response.thoughts:
|
|
||||||
speak_to_user = prompt_define_response.thoughts.get("speak")
|
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
|
||||||
else:
|
|
||||||
speak_to_user = str(prompt_define_response.thoughts)
|
|
||||||
else:
|
|
||||||
if hasattr(prompt_define_response.thoughts, "speak"):
|
|
||||||
speak_to_user = prompt_define_response.thoughts.get("speak")
|
|
||||||
elif hasattr(prompt_define_response.thoughts, "reasoning"):
|
|
||||||
speak_to_user = prompt_define_response.thoughts.get("reasoning")
|
|
||||||
else:
|
|
||||||
speak_to_user = prompt_define_response.thoughts
|
|
||||||
else:
|
|
||||||
speak_to_user = prompt_define_response
|
|
||||||
view_message = self.prompt_template.output_parser.parse_view_response(
|
|
||||||
speak_to_user, result
|
|
||||||
)
|
|
||||||
self.current_message.add_view_message(view_message)
|
self.current_message.add_view_message(view_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
@ -95,6 +95,7 @@ class ModelMessageRoleType:
|
|||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
HUMAN = "human"
|
HUMAN = "human"
|
||||||
AI = "ai"
|
AI = "ai"
|
||||||
|
VIEW = "view"
|
||||||
|
|
||||||
|
|
||||||
class Generation(BaseModel):
|
class Generation(BaseModel):
|
||||||
|
@ -1,26 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, NamedTuple, List
|
from typing import List
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
from pilot.scene.base_message import (
|
|
||||||
HumanMessage,
|
|
||||||
ViewMessage,
|
|
||||||
)
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.common.markdown_text import (
|
|
||||||
generate_htm_table,
|
|
||||||
)
|
|
||||||
from pilot.scene.chat_dashboard.prompt import prompt
|
|
||||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||||
ChartData,
|
ChartData,
|
||||||
ReportData,
|
ReportData,
|
||||||
ValueItem,
|
|
||||||
)
|
)
|
||||||
|
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -85,42 +75,10 @@ class ChatDashboard(BaseChat):
|
|||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
|
### TODO 记录整体信息,处理成功的,和未成功的分开记录处理
|
||||||
chart_datas: List[ChartData] = []
|
chart_datas: List[ChartData] = []
|
||||||
|
dashboard_data_loader = DashboardDataLoader()
|
||||||
for chart_item in prompt_response:
|
for chart_item in prompt_response:
|
||||||
try:
|
try:
|
||||||
field_names, datas = self.database.query_ex(
|
field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.db_connect, chart_item.sql)
|
||||||
self.db_connect, chart_item.sql
|
|
||||||
)
|
|
||||||
values: List[ValueItem] = []
|
|
||||||
data_map = {}
|
|
||||||
field_map = {}
|
|
||||||
index = 0
|
|
||||||
for field_name in field_names:
|
|
||||||
data_map.update({f"{field_name}": [row[index] for row in datas]})
|
|
||||||
index += 1
|
|
||||||
if not data_map[field_name]:
|
|
||||||
field_map.update({f"{field_name}": False})
|
|
||||||
else:
|
|
||||||
field_map.update(
|
|
||||||
{
|
|
||||||
f"{field_name}": all(
|
|
||||||
isinstance(item, (int, float, Decimal))
|
|
||||||
for item in data_map[field_name]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name in field_names[1:]:
|
|
||||||
if not field_map[field_name]:
|
|
||||||
print("more than 2 non-numeric column")
|
|
||||||
else:
|
|
||||||
for data in datas:
|
|
||||||
value_item = ValueItem(
|
|
||||||
name=data[0],
|
|
||||||
type=field_name,
|
|
||||||
value=data[field_names.index(field_name)],
|
|
||||||
)
|
|
||||||
values.append(value_item)
|
|
||||||
|
|
||||||
chart_datas.append(
|
chart_datas.append(
|
||||||
ChartData(
|
ChartData(
|
||||||
chart_uid=str(uuid.uuid1()),
|
chart_uid=str(uuid.uuid1()),
|
||||||
@ -135,7 +93,6 @@ class ChatDashboard(BaseChat):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO 修复流程
|
# TODO 修复流程
|
||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
return ReportData(
|
return ReportData(
|
||||||
conv_uid=self.chat_session_id,
|
conv_uid=self.chat_session_id,
|
||||||
template_name=self.report_name,
|
template_name=self.report_name,
|
||||||
|
57
pilot/scene/chat_dashboard/data_loader.py
Normal file
57
pilot/scene/chat_dashboard/data_loader.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from typing import List
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
|
||||||
|
|
||||||
|
|
||||||
|
class DashboardDataLoader:
|
||||||
|
|
||||||
|
def get_chart_values_by_conn(self, db_conn, chart_sql: str) :
|
||||||
|
logger.info(f"get_chart_values_by_conn:{chart_sql}")
|
||||||
|
try:
|
||||||
|
field_names, datas = db_conn.query_ex(chart_sql)
|
||||||
|
values: List[ValueItem] = []
|
||||||
|
data_map = {}
|
||||||
|
field_map = {}
|
||||||
|
index = 0
|
||||||
|
for field_name in field_names:
|
||||||
|
data_map.update({f"{field_name}": [row[index] for row in datas]})
|
||||||
|
index += 1
|
||||||
|
if not data_map[field_name]:
|
||||||
|
field_map.update({f"{field_name}": False})
|
||||||
|
else:
|
||||||
|
field_map.update(
|
||||||
|
{
|
||||||
|
f"{field_name}": all(
|
||||||
|
isinstance(item, (int, float, Decimal))
|
||||||
|
for item in data_map[field_name]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name in field_names[1:]:
|
||||||
|
if not field_map[field_name]:
|
||||||
|
logger.info("More than 2 non-numeric column:" + field_name)
|
||||||
|
else:
|
||||||
|
for data in datas:
|
||||||
|
value_item = ValueItem(
|
||||||
|
name=data[0],
|
||||||
|
type=field_name,
|
||||||
|
value=data[field_names.index(field_name)],
|
||||||
|
)
|
||||||
|
values.append(value_item)
|
||||||
|
return field_names, values
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Prepare Chart Data Faild!" + str(e))
|
||||||
|
raise ValueError("Prepare Chart Data Faild!")
|
||||||
|
|
||||||
|
def get_chart_values_by_db(self, db_name: str, chart_sql: str) :
|
||||||
|
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
|
||||||
|
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
return self.get_chart_values_by_conn(db_conn, chart_sql)
|
@ -22,11 +22,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, chat_session_id, db_name, user_input):
|
def __init__(self, chat_session_id, db_name, user_input):
|
||||||
|
chat_mode = ChatScene.ChatWithDbExecute
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatWithDbExecute,
|
chat_mode=chat_mode,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
|
select_param=db_name,
|
||||||
)
|
)
|
||||||
if not db_name:
|
if not db_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -35,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
self.db_connect = self.database.session
|
self.db_connect = self.database.session
|
||||||
self.top_k: int = 5
|
self.top_k: int = 200
|
||||||
|
|
||||||
def generate_input_values(self):
|
def generate_input_values(self):
|
||||||
try:
|
try:
|
||||||
|
8
pilot/scene/chat_db/auto_execute/data_loader.py
Normal file
8
pilot/scene/chat_db/auto_execute/data_loader.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class DbDataLoader:
|
||||||
|
|
||||||
|
|
||||||
|
def get_table_view_by_conn(self, db_conn, chart_sql: str):
|
||||||
|
pass
|
@ -25,6 +25,7 @@ class ChatWithDbQA(BaseChat):
|
|||||||
chat_mode=ChatScene.ChatWithDbQA,
|
chat_mode=ChatScene.ChatWithDbQA,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
|
select_param=db_name,
|
||||||
)
|
)
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
if db_name:
|
if db_name:
|
||||||
|
@ -30,6 +30,7 @@ class ChatWithPlugin(BaseChat):
|
|||||||
chat_mode=ChatScene.ChatExecution,
|
chat_mode=ChatScene.ChatExecution,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
|
select_param=plugin_selector,
|
||||||
)
|
)
|
||||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||||
|
@ -33,6 +33,7 @@ class ChatNewKnowledge(BaseChat):
|
|||||||
chat_mode=ChatScene.ChatNewKnowledge,
|
chat_mode=ChatScene.ChatNewKnowledge,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
|
select_param=knowledge_name,
|
||||||
)
|
)
|
||||||
self.knowledge_name = knowledge_name
|
self.knowledge_name = knowledge_name
|
||||||
vector_store_config = {
|
vector_store_config = {
|
||||||
|
@ -24,6 +24,7 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
chat_mode=ChatScene.InnerChatDBSummary,
|
chat_mode=ChatScene.InnerChatDBSummary,
|
||||||
chat_session_id=chat_session_id,
|
chat_session_id=chat_session_id,
|
||||||
current_user_input=user_input,
|
current_user_input=user_input,
|
||||||
|
select_param=db_select,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db_input = db_select
|
self.db_input = db_select
|
||||||
|
@ -30,6 +30,8 @@ class OnceConversation:
|
|||||||
self.messages: List[BaseMessage] = []
|
self.messages: List[BaseMessage] = []
|
||||||
self.start_date: str = ""
|
self.start_date: str = ""
|
||||||
self.chat_order: int = 0
|
self.chat_order: int = 0
|
||||||
|
self.param_type: str = ""
|
||||||
|
self.param_value: str = ""
|
||||||
self.cost: int = 0
|
self.cost: int = 0
|
||||||
self.tokens: int = 0
|
self.tokens: int = 0
|
||||||
|
|
||||||
@ -114,9 +116,13 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
|
|||||||
"cost": once.cost if once.cost else 0,
|
"cost": once.cost if once.cost else 0,
|
||||||
"tokens": once.tokens if once.tokens else 0,
|
"tokens": once.tokens if once.tokens else 0,
|
||||||
"messages": messages_to_dict(once.messages),
|
"messages": messages_to_dict(once.messages),
|
||||||
|
"param_type": once.param_type,
|
||||||
|
"param_value": once.param_value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
|
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
|
||||||
return [_conversation_to_dic(m) for m in conversations]
|
return [_conversation_to_dic(m) for m in conversations]
|
||||||
|
|
||||||
@ -128,6 +134,8 @@ def conversation_from_dict(once: dict) -> OnceConversation:
|
|||||||
conversation.tokens = once.get("tokens", 0)
|
conversation.tokens = once.get("tokens", 0)
|
||||||
conversation.start_date = once.get("start_date", "")
|
conversation.start_date = once.get("start_date", "")
|
||||||
conversation.chat_order = int(once.get("chat_order"))
|
conversation.chat_order = int(once.get("chat_order"))
|
||||||
|
conversation.param_type = once.get("param_type", "")
|
||||||
|
conversation.param_value = once.get("param_value", "")
|
||||||
print(once.get("messages"))
|
print(once.get("messages"))
|
||||||
conversation.messages = messages_from_dict(once.get("messages", []))
|
conversation.messages = messages_from_dict(once.get("messages", []))
|
||||||
return conversation
|
return conversation
|
||||||
|
@ -51,19 +51,15 @@ def build_logger(logger_name, logger_filename):
|
|||||||
logging.getLogger().handlers[0].setFormatter(formatter)
|
logging.getLogger().handlers[0].setFormatter(formatter)
|
||||||
|
|
||||||
# Redirect stdout and stderr to loggers
|
# Redirect stdout and stderr to loggers
|
||||||
stdout_logger = logging.getLogger("stdout")
|
# stdout_logger = logging.getLogger("stdout")
|
||||||
stdout_logger.setLevel(logging.INFO)
|
# stdout_logger.setLevel(logging.INFO)
|
||||||
sl = StreamToLogger(stdout_logger, logging.INFO)
|
# sl_1 = StreamToLogger(stdout_logger, logging.INFO)
|
||||||
sys.stdout = sl
|
# sys.stdout = sl_1
|
||||||
|
#
|
||||||
stderr_logger = logging.getLogger("stderr")
|
# stderr_logger = logging.getLogger("stderr")
|
||||||
stderr_logger.setLevel(logging.ERROR)
|
# stderr_logger.setLevel(logging.ERROR)
|
||||||
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
# sl = StreamToLogger(stderr_logger, logging.ERROR)
|
||||||
sys.stderr = sl
|
# sys.stderr = sl
|
||||||
|
|
||||||
# Get logger
|
|
||||||
logger = logging.getLogger(logger_name)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# Add a file handler for all loggers
|
# Add a file handler for all loggers
|
||||||
if handler is None:
|
if handler is None:
|
||||||
@ -78,6 +74,12 @@ def build_logger(logger_name, logger_filename):
|
|||||||
if isinstance(item, logging.Logger):
|
if isinstance(item, logging.Logger):
|
||||||
item.addHandler(handler)
|
item.addHandler(handler)
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
# Get logger
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user