Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-09 12:18:12 +00:00)

feat(editor): editor api develop

editor api develop, part 1

This commit is contained in:
parent de3bcd1f68
commit a95d292bf9
@@ -250,17 +250,17 @@ class RDBMSDatabase(BaseConnect):
        """Format the error message"""
        return f"Error: {e}"

    def __write(self, session, write_sql):
    def __write(self, write_sql):
        print(f"Write[{write_sql}]")
        db_cache = self._engine.url.database
        result = session.execute(text(write_sql))
        session.commit()
        result = self.session.execute(text(write_sql))
        self.session.commit()
        # TODO Subsequent optimization of dynamically specified database submission loss target problem
        session.execute(text(f"use `{db_cache}`"))
        self.session.execute(text(f"use `{db_cache}`"))
        print(f"SQL[{write_sql}], result:{result.rowcount}")
        return result.rowcount

    def __query(self, session, query, fetch: str = "all"):
    def __query(self,query, fetch: str = "all"):
        """
        only for query
        Args:
@@ -274,7 +274,7 @@ class RDBMSDatabase(BaseConnect):
        print(f"Query[{query}]")
        if not query:
            return []
        cursor = session.execute(text(query))
        cursor = self.session.execute(text(query))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
@@ -288,7 +288,7 @@ class RDBMSDatabase(BaseConnect):
        result.insert(0, field_names)
        return result

    def query_ex(self, session, query, fetch: str = "all"):
    def query_ex(self, query, fetch: str = "all"):
        """
        only for query
        Args:
@@ -300,7 +300,7 @@ class RDBMSDatabase(BaseConnect):
        print(f"Query[{query}]")
        if not query:
            return []
        cursor = session.execute(text(query))
        cursor = self.session.execute(text(query))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
@@ -313,7 +313,7 @@ class RDBMSDatabase(BaseConnect):
        result = list(result)
        return field_names, result

    def run(self, session, command: str, fetch: str = "all") -> List:
    def run(self, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results."""
        print("SQL:" + command)
        if not command:
@@ -321,17 +321,17 @@ class RDBMSDatabase(BaseConnect):
        parsed, ttype, sql_type, table_name = self.__sql_parse(command)
        if ttype == sqlparse.tokens.DML:
            if sql_type == "SELECT":
                return self.__query(session, command, fetch)
                return self.__query( command, fetch)
            else:
                self.__write(session, command)
                self.__write( command)
                select_sql = self.convert_sql_write_to_select(command)
                print(f"write result query:{select_sql}")
                return self.__query(session, select_sql)
                return self.__query( select_sql)

        else:
            print(f"DDL execution determines whether to enable through configuration ")
            cursor = session.execute(text(command))
            session.commit()
            cursor = self.session.execute(text(command))
            self.session.commit()
            if cursor.returns_rows:
                result = cursor.fetchall()
                field_names = tuple(i[0:] for i in cursor.keys())
@@ -339,10 +339,10 @@ class RDBMSDatabase(BaseConnect):
                result.insert(0, field_names)
                print("DDL Result:" + str(result))
                if not result:
                    return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
                    return self.__query( f"SHOW COLUMNS FROM {table_name}")
                return result
            else:
                return self.__query(session, f"SHOW COLUMNS FROM {table_name}")
                return self.__query( f"SHOW COLUMNS FROM {table_name}")

    def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results.
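The hunks above drop the explicit session argument from __write, __query, query_ex and run; the connector now executes against its own self.session. A minimal usage sketch under that assumption, with the database name and SQL invented for illustration:

from pilot.configs.config import Config

CFG = Config()

# Obtain a connector from the manager used throughout this commit;
# "db_test" and the SQL statements are placeholder values.
conn = CFG.LOCAL_DB_MANAGE.get_connect("db_test")
rows = conn.run("SELECT * FROM transaction_order LIMIT 5")          # no session argument any more
field_names, values = conn.query_ex("SELECT city, amount FROM transaction_order")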
@@ -94,6 +94,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
        cursor.commit()
        self.connect.commit()

    def update(self, messages:List[OnceConversation]) -> None:
        cursor = self.connect.cursor()
        cursor.execute(
            "UPDATE chat_history set messages=? where conv_uid=?",
            [json.dumps(messages, ensure_ascii=False), self.chat_seesion_id],
        )
        cursor.commit()
        self.connect.commit()

    def clear(self) -> None:
        cursor = self.connect.cursor()
        cursor.execute(
@@ -134,6 +144,24 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):

        return []

    def conv_info(self, conv_uid: str = None) -> None:
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT * FROM chat_history where conv_uid=? ",
            [conv_uid],
        )
        # Get the field names of the query result
        fields = [field[0] for field in cursor.description]

        for row in cursor.fetchone():
            row_dict = {}
            for i, field in enumerate(fields):
                row_dict[field] = row[i]
            return row_dict

        return {}

    def get_messages(self) -> List[OnceConversation]:
        cursor = self.connect.cursor()
        cursor.execute(
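The new update() method rewrites the stored messages blob for a conversation, and the editor endpoints later in this commit use it to blank out edited rounds. A rough round-trip sketch; the conversation id is a placeholder:

from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory

history_mem = DuckdbHistoryMemory("example-conv-uid")        # placeholder conv_uid
history_messages = history_mem.get_messages()                # serialized OnceConversation dicts
if history_messages:
    for element in history_messages[0]["messages"]:
        if element["type"] == "ai":
            element["data"]["content"] = ""                  # clear the AI answer for this round
    history_mem.update(history_messages)                     # persist the edited history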
@@ -1,17 +1,12 @@
import os

import json
from fastapi import (
    APIRouter,
    Request,
    Body,
    BackgroundTasks,
)

from typing import List

from pilot.configs.config import Config
from pilot.server.knowledge.service import KnowledgeService

from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
@@ -19,9 +14,6 @@ from pilot.utils import build_logger

from pilot.openapi.api_view_model import (
    Result,
    ConversationVo,
    MessageVo,
    ChatSceneVo,
)
from pilot.openapi.editor_view_model import (
    ChatDbRounds,
@@ -31,9 +23,11 @@ from pilot.openapi.editor_view_model import (
    DbTable
)

from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData
from pilot.openapi.api_v1.editor.sql_editor import DataNode,ChartRunData,SqlRunData
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader

from pilot.scene.chat_db.auto_execute.out_parser import SqlAction

router = APIRouter()
CFG = Config()
@@ -42,47 +36,165 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")


@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str= ""):
    return Result.succ(None)

async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
    logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str)
    db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
    tables = db_conn.get_table_names()
    db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
    for table in tables:
        table_node: DataNode = DataNode(title=table, key=table, type="table")
        db_node.children.append(table_node)
        fields = db_conn.get_fields("transaction_order")
        for field in fields:
            table_node.children.append(
                DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3],
                         comment=field[-1]))

    return Result.succ(db_node)


@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str):
    return Result.succ(None)
    logger.info("get_editor_sql_rounds:{}", con_uid)
    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        result: List = []
        for once in history_messages:
            round_name: str = ""
            for element in once["messages"]:
                if element["type"] == "human":
                    round_name = element["data"]["content"]
            if once.get("param_value"):
                round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"],
                                                   round_name=round_name)
                result.append(round)
    return Result.succ(result)
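For reference, a client-side sketch of the two endpoints above once the router is mounted; the requests library, the base URL, and the parameter values are assumptions for illustration, not part of this commit:

import requests  # assumed HTTP client, not used by the commit itself

BASE_URL = "http://localhost:5000"                           # assumed deployment address

tables = requests.get(f"{BASE_URL}/v1/editor/db/tables",
                      params={"db_name": "db_test", "page_index": 1, "page_size": 20}).json()
rounds = requests.get(f"{BASE_URL}/v1/editor/sql/rounds",
                      params={"con_uid": "example-conv-uid"}).json()
print(tables, rounds)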
@router.get("/v1/editor/sql", response_model=Result[SqlAction])
|
||||
@router.get("/v1/editor/sql", response_model=Result[dict])
|
||||
async def get_editor_sql(con_uid: str, round: int):
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart/details", response_model=Result[ChartDetail])
|
||||
async def get_editor_sql_rounds(con_uid: str):
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart(con_uid: str, chart_uid: str):
|
||||
return Result.succ(None)
|
||||
logger.info("get_editor_sql:{},{}", con_uid, round)
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
for once in history_messages:
|
||||
if int(once["chat_order"]) == round:
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "ai":
|
||||
return Result.succ(json.loads(element["data"]["content"]))
|
||||
return Result.faild("没有获取到可用的SQL返回结构")
|
||||
|
||||
|
||||
@router.post("/v1/editor/sql/run", response_model=Result[List[dict]])
|
||||
async def get_editor_chart(db_name: str, sql: str):
|
||||
return Result.succ(None)
|
||||
async def editor_sql_run(db_name: str, sql: str):
|
||||
logger.info("get_editor_sql_run:{},{}", db_name, sql)
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
return Result.succ(conn.run(sql))
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartData])
|
||||
async def get_editor_chart(db_name: str, sql: str):
|
||||
return Result.succ(None)
|
||||
@router.post("/v1/sql/editor/submit", response_model=Result)
|
||||
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
|
||||
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
|
||||
if edit_round:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
element["data"]["content"]=""
|
||||
if element["type"] == "view":
|
||||
element["data"]["content"]=""
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.faild("Edit Faild!")
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart/list", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart_list(con_uid: str):
|
||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
return Result.succ(json.loads(element["data"]["content"]))
|
||||
|
||||
return Result.faild("没有获取到可用的SQL返回结构")
|
||||
|
||||
|
||||
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart_info(con_uid: str, chart_uid: str):
|
||||
logger.info(f"get_editor_sql_rounds:{con_uid}")
|
||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
db_name = last_round["param_value"]
|
||||
if not db_name:
|
||||
logger.error("this dashboard dialogue version too old, can't support editor!")
|
||||
return Result.faild("this dashboard dialogue version too old, can't support editor!")
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_uid, charts))[0]
|
||||
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'],
|
||||
chart_type=find_chart['chart_type'],
|
||||
chart_desc=find_chart['chart_desc'],
|
||||
chart_sql=find_chart['chart_sql'],
|
||||
db_name=db_name,
|
||||
chart_name=find_chart['chart_name'],
|
||||
chart_value=find_chart['values'],
|
||||
table_value=conn.run(find_chart['chart_sql'])
|
||||
)
|
||||
|
||||
return Result.succ(detail)
|
||||
return Result.faild("Can't Find Chart Detail Info!")
|
||||
|
||||
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||
async def editor_chart_run(db_name: str, sql: str):
|
||||
logger.info(f"editor_chart_run:{db_name},{sql}")
|
||||
dashboard_data_loader:DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
|
||||
field_names,chart_values = dashboard_data_loader.get_chart_values_by_db(db_conn, sql)
|
||||
|
||||
sql_run_data:SqlRunData = SqlRunData(result_info="",
|
||||
run_cost="",
|
||||
colunms= field_names,
|
||||
values= db_conn.query_ex(sql)
|
||||
)
|
||||
return Result.succ(ChartRunData(sql_data=sql_run_data,chart_values=chart_values))
|
||||
|
||||
|
||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
||||
return Result.succ(None)
|
||||
logger.info(f"sql_editor_submit:{chart_edit_context.__dict__}")
|
||||
history_mem = DuckdbHistoryMemory(chart_edit_context.conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
edit_round = list(filter(lambda x: x['chat_order'] == chart_edit_context.conv_round, history_messages))[0]
|
||||
if edit_round:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
|
||||
|
||||
|
||||
@router.post("/v1/sql/editor/submit", response_model=Result[bool])
|
||||
async def chart_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||
return Result.succ(None)
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_uid, charts))[0]
|
||||
|
||||
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.faild("Edit Faild!")
|
||||
|
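The submit endpoint above implements a simple edit flow: find the edited round in DuckDB history, blank its ai/view messages, and write the whole history back via update(). A hedged request sketch; only conv_uid and conv_round are confirmed by the code above, while the remaining ChatSqlEditContext fields and the base URL are assumptions:

import requests  # assumed HTTP client

payload = {
    "conv_uid": "example-conv-uid",    # placeholder conversation id
    "conv_round": 1,                   # round whose ai/view messages will be cleared
    # ...plus whatever edit fields ChatSqlEditContext defines (e.g. the corrected SQL)
}
resp = requests.post("http://localhost:5000/v1/sql/editor/submit", json=payload)
print(resp.json())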
pilot/openapi/api_v1/editor/sql_editor.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from typing import List
from pydantic import BaseModel, Field, root_validator, validator, Extra
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem


class DataNode(BaseModel):
    title: str
    key: str

    type: str = ""
    default_value: str = None
    can_null: str = 'YES'
    comment: str = None
    children: List = []


class SqlRunData(BaseModel):
    result_info: str
    run_cost: str
    colunms: List[str]
    values: List


class ChartRunData(BaseModel):
    sql_data: SqlRunData
    chart_values: List[ValueItem]
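These pydantic models back the editor responses: DataNode describes the db/table/column tree, SqlRunData the tabular result of a SQL run, and ChartRunData bundles that with chart values. A small construction sketch with example values (note that the colunms field name is spelled exactly as in the model):

from pilot.openapi.api_v1.editor.sql_editor import DataNode, SqlRunData

# Build a tiny db -> table tree and a one-row result set for illustration.
db_node = DataNode(title="db_test", key="db_test", type="db")
table_node = DataNode(title="transaction_order", key="transaction_order", type="table")
db_node.children.append(table_node)

sql_data = SqlRunData(result_info="", run_cost="", colunms=["city", "amount"], values=[["beijing", 9.9]])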
@@ -24,10 +24,12 @@ class ChatDbRounds(BaseModel):

class ChartDetail(BaseModel):
    chart_uid: str
    chart_type: str
    chart_desc: str
    chart_sql: str
    db_name: str
    chart_name: str
    chart_value: str
    chat_round: int  # default last round
    chart_value: Any
    table_value: Any


class ChatChartEditContext(BaseModel):
@@ -60,10 +60,11 @@ class BaseChat(ABC):
        arbitrary_types_allowed = True

    def __init__(
        self,
        chat_mode,
        chat_session_id,
        current_user_input,
        self,
        chat_mode,
        chat_session_id,
        current_user_input,
        select_param: Any = None
    ):
        self.chat_session_id = chat_session_id
        self.chat_mode = chat_mode
@@ -87,6 +88,9 @@ class BaseChat(ABC):
        )
        self.history_message: List[OnceConversation] = self.memory.messages()
        self.current_message: OnceConversation = OnceConversation(chat_mode.value())
        if select_param:
            self.current_message.param_type = chat_mode.param_types()[0]
            self.current_message.param_value = select_param
        self.current_tokens_used: int = 0

    class Config:
@@ -111,6 +115,24 @@ class BaseChat(ABC):
    def do_action(self, prompt_response):
        return prompt_response

    def get_llm_speak(self, prompt_define_response):
        if hasattr(prompt_define_response, "thoughts"):
            if isinstance(prompt_define_response.thoughts, dict):
                if "speak" in prompt_define_response.thoughts:
                    speak_to_user = prompt_define_response.thoughts.get("speak")
                else:
                    speak_to_user = str(prompt_define_response.thoughts)
            else:
                if hasattr(prompt_define_response.thoughts, "speak"):
                    speak_to_user = prompt_define_response.thoughts.get("speak")
                elif hasattr(prompt_define_response.thoughts, "reasoning"):
                    speak_to_user = prompt_define_response.thoughts.get("reasoning")
                else:
                    speak_to_user = prompt_define_response.thoughts
        else:
            speak_to_user = prompt_define_response
        return speak_to_user

    def __call_base(self):
        input_values = self.generate_input_values()
        ### Chat sequence advance
@@ -209,26 +231,13 @@ class BaseChat(ABC):
                    ai_response_text
                )
            )
            ### sql run
            result = self.do_action(prompt_define_response)

            if hasattr(prompt_define_response, "thoughts"):
                if isinstance(prompt_define_response.thoughts, dict):
                    if "speak" in prompt_define_response.thoughts:
                        speak_to_user = prompt_define_response.thoughts.get("speak")
                    else:
                        speak_to_user = str(prompt_define_response.thoughts)
                else:
                    if hasattr(prompt_define_response.thoughts, "speak"):
                        speak_to_user = prompt_define_response.thoughts.get("speak")
                    elif hasattr(prompt_define_response.thoughts, "reasoning"):
                        speak_to_user = prompt_define_response.thoughts.get("reasoning")
                    else:
                        speak_to_user = prompt_define_response.thoughts
            else:
                speak_to_user = prompt_define_response
            view_message = self.prompt_template.output_parser.parse_view_response(
                speak_to_user, result
            )
            ### llm speaker
            speak_to_user = self.get_llm_speak(prompt_define_response)

            view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
            self.current_message.add_view_message(view_message)
        except Exception as e:
            print(traceback.format_exc())
@@ -297,7 +306,7 @@ class BaseChat(ABC):
        system_messages = []
        for system_conv in system_convs:
            system_text += (
                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
            )
            system_messages.append(
                ModelMessage(role=system_conv.type, content=system_conv.content)
@@ -309,7 +318,7 @@ class BaseChat(ABC):
        user_messages = []
        if user_conv:
            user_text = (
                user_conv.type + ":" + user_conv.content + self.prompt_template.sep
                user_conv.type + ":" + user_conv.content + self.prompt_template.sep
            )
            user_messages.append(
                ModelMessage(role=user_conv.type, content=user_conv.content)
@@ -331,10 +340,10 @@ class BaseChat(ABC):
            message_type = round_message["type"]
            message_content = round_message["data"]["content"]
            example_text += (
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
            )
            example_messages.append(
                ModelMessage(role=message_type, content=message_content)
@@ -358,10 +367,10 @@ class BaseChat(ABC):
            message_type = first_message["type"]
            message_content = first_message["data"]["content"]
            history_text += (
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
            )
            history_messages.append(
                ModelMessage(role=message_type, content=message_content)
@@ -377,10 +386,10 @@ class BaseChat(ABC):
            message_type = round_message["type"]
            message_content = round_message["data"]["content"]
            history_text += (
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
            )
            history_messages.append(
                ModelMessage(role=message_type, content=message_content)
@@ -398,10 +407,10 @@ class BaseChat(ABC):
            message_type = message["type"]
            message_content = message["data"]["content"]
            history_text += (
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
                message_type
                + ":"
                + message_content
                + self.prompt_template.sep
            )
            history_messages.append(
                ModelMessage(role=message_type, content=message_content)
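The duplicated thought-extraction branch in the call path is hoisted into the get_llm_speak helper above. A standalone illustration of the selection order it encodes, using a stub response object rather than the project's real parser output:

class StubResponse:
    """Stand-in for a parsed LLM response carrying a `thoughts` payload."""

    def __init__(self, thoughts):
        self.thoughts = thoughts


def pick_speak(resp):
    # Mirrors get_llm_speak: prefer a dict's "speak" entry, then the whole dict,
    # then a speak/reasoning attribute, then the raw thoughts or response.
    if hasattr(resp, "thoughts"):
        if isinstance(resp.thoughts, dict):
            return resp.thoughts.get("speak", str(resp.thoughts))
        return getattr(resp.thoughts, "speak",
                       getattr(resp.thoughts, "reasoning", resp.thoughts))
    return resp


print(pick_speak(StubResponse({"speak": "Running your SQL now."})))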
@@ -95,6 +95,7 @@ class ModelMessageRoleType:
    SYSTEM = "system"
    HUMAN = "human"
    AI = "ai"
    VIEW = "view"


class Generation(BaseModel):
@@ -166,7 +167,7 @@ def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:


def _parse_model_messages(
    messages: List[ModelMessage],
    messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str, str]]]:
    """ "
    Parameters:
@@ -1,26 +1,16 @@
import json
import os
import uuid
from typing import Dict, NamedTuple, List
from decimal import Decimal
from typing import List

from pilot.scene.base_message import (
    HumanMessage,
    ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
    generate_htm_table,
)
from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
    ChartData,
    ReportData,
    ValueItem,
)
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader

CFG = Config()

@@ -85,42 +75,10 @@ class ChatDashboard(BaseChat):
    def do_action(self, prompt_response):
        ### TODO: record overall info; handle successful and failed items separately
        chart_datas: List[ChartData] = []
        dashboard_data_loader = DashboardDataLoader()
        for chart_item in prompt_response:
            try:
                field_names, datas = self.database.query_ex(
                    self.db_connect, chart_item.sql
                )
                values: List[ValueItem] = []
                data_map = {}
                field_map = {}
                index = 0
                for field_name in field_names:
                    data_map.update({f"{field_name}": [row[index] for row in datas]})
                    index += 1
                    if not data_map[field_name]:
                        field_map.update({f"{field_name}": False})
                    else:
                        field_map.update(
                            {
                                f"{field_name}": all(
                                    isinstance(item, (int, float, Decimal))
                                    for item in data_map[field_name]
                                )
                            }
                        )

                for field_name in field_names[1:]:
                    if not field_map[field_name]:
                        print("more than 2 non-numeric column")
                    else:
                        for data in datas:
                            value_item = ValueItem(
                                name=data[0],
                                type=field_name,
                                value=data[field_names.index(field_name)],
                            )
                            values.append(value_item)

                field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.db_connect, chart_item.sql)
                chart_datas.append(
                    ChartData(
                        chart_uid=str(uuid.uuid1()),
@@ -135,7 +93,6 @@ class ChatDashboard(BaseChat):
            except Exception as e:
                # TODO: fix this flow
                print(str(e))

        return ReportData(
            conv_uid=self.chat_session_id,
            template_name=self.report_name,
pilot/scene/chat_dashboard/data_loader.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from typing import List
from decimal import Decimal

from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem

CFG = Config()
logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")


class DashboardDataLoader:

    def get_chart_values_by_conn(self, db_conn, chart_sql: str) :
        logger.info(f"get_chart_values_by_conn:{chart_sql}")
        try:
            field_names, datas = db_conn.query_ex(chart_sql)
            values: List[ValueItem] = []
            data_map = {}
            field_map = {}
            index = 0
            for field_name in field_names:
                data_map.update({f"{field_name}": [row[index] for row in datas]})
                index += 1
                if not data_map[field_name]:
                    field_map.update({f"{field_name}": False})
                else:
                    field_map.update(
                        {
                            f"{field_name}": all(
                                isinstance(item, (int, float, Decimal))
                                for item in data_map[field_name]
                            )
                        }
                    )

            for field_name in field_names[1:]:
                if not field_map[field_name]:
                    logger.info("More than 2 non-numeric column:" + field_name)
                else:
                    for data in datas:
                        value_item = ValueItem(
                            name=data[0],
                            type=field_name,
                            value=data[field_names.index(field_name)],
                        )
                        values.append(value_item)
            return field_names, values
        except Exception as e:
            logger.debug("Prepare Chart Data Faild!" + str(e))
            raise ValueError("Prepare Chart Data Faild!")

    def get_chart_values_by_db(self, db_name: str, chart_sql: str) :
        logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
        db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
        return self.get_chart_values_by_conn(db_conn, chart_sql)
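DashboardDataLoader pivots a query result into ValueItem rows keyed by the first (label) column, one item per numeric column. A usage sketch; the database name and SQL are placeholders:

from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader

loader = DashboardDataLoader()
field_names, chart_values = loader.get_chart_values_by_db(
    "db_test", "SELECT city, SUM(amount) FROM transaction_order GROUP BY city"
)
for item in chart_values:
    print(item.name, item.type, item.value)    # ValueItem fields, as used above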
@@ -22,11 +22,13 @@ class ChatWithDbAutoExecute(BaseChat):
    """Number of results to return from the query"""

    def __init__(self, chat_session_id, db_name, user_input):
        chat_mode = ChatScene.ChatWithDbExecute
        """ """
        super().__init__(
            chat_mode=ChatScene.ChatWithDbExecute,
            chat_mode=chat_mode,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=db_name,
        )
        if not db_name:
            raise ValueError(
@@ -35,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat):
        self.db_name = db_name
        self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
        self.db_connect = self.database.session
        self.top_k: int = 5
        self.top_k: int = 200

    def generate_input_values(self):
        try:
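ChatWithDbAutoExecute now forwards the database name to BaseChat as select_param, so it is persisted as OnceConversation.param_value and the editor endpoints can recover which database a round ran against. A construction sketch with placeholder arguments; the import path assumes the class lives in the package's chat module:

from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute

chat = ChatWithDbAutoExecute(
    chat_session_id="example-session-id",   # placeholder id
    db_name="db_test",                      # placeholder database; stored as param_value
    user_input="Top 10 cities by order amount",
)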
pilot/scene/chat_db/auto_execute/data_loader.py (new file, 8 lines)
@@ -0,0 +1,8 @@
import json


class DbDataLoader:

    def get_table_view_by_conn(self, db_conn, chart_sql: str):
        pass
@@ -25,6 +25,7 @@ class ChatWithDbQA(BaseChat):
            chat_mode=ChatScene.ChatWithDbQA,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=db_name,
        )
        self.db_name = db_name
        if db_name:
@@ -30,6 +30,8 @@ class ChatWithPlugin(BaseChat):
            chat_mode=ChatScene.ChatExecution,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=plugin_selector,
        )
        self.plugins_prompt_generator = PluginPromptGenerator()
        self.plugins_prompt_generator.command_registry = CFG.command_registry
@@ -33,6 +33,7 @@ class ChatNewKnowledge(BaseChat):
            chat_mode=ChatScene.ChatNewKnowledge,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=knowledge_name,
        )
        self.knowledge_name = knowledge_name
        vector_store_config = {
@@ -24,6 +24,7 @@ class InnerChatDBSummary(BaseChat):
            chat_mode=ChatScene.InnerChatDBSummary,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=db_select,
        )

        self.db_input = db_select
@@ -30,6 +30,8 @@ class OnceConversation:
        self.messages: List[BaseMessage] = []
        self.start_date: str = ""
        self.chat_order: int = 0
        self.param_type: str = ""
        self.param_value: str = ""
        self.cost: int = 0
        self.tokens: int = 0

@@ -114,9 +116,13 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
        "cost": once.cost if once.cost else 0,
        "tokens": once.tokens if once.tokens else 0,
        "messages": messages_to_dict(once.messages),
        "param_type": once.param_type,
        "param_value": once.param_value
    }


def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
    return [_conversation_to_dic(m) for m in conversations]

@@ -128,6 +134,8 @@ def conversation_from_dict(once: dict) -> OnceConversation:
    conversation.tokens = once.get("tokens", 0)
    conversation.start_date = once.get("start_date", "")
    conversation.chat_order = int(once.get("chat_order"))
    conversation.param_type = once.get("param_type", "")
    conversation.param_value = once.get("param_value", "")
    print(once.get("messages"))
    conversation.messages = messages_from_dict(once.get("messages", []))
    return conversation
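OnceConversation now carries param_type/param_value, and both fields survive the dict round trip used by the DuckDB history store. A minimal sketch, with the chat mode and values invented:

from pilot.scene.message import OnceConversation, _conversation_to_dic

conv = OnceConversation("chat_with_db_execute")    # example chat mode value
conv.param_type = "db_name"                        # example parameter kind
conv.param_value = "db_test"                       # example database name
data = _conversation_to_dic(conv)
print(data["param_type"], data["param_value"])     # both keys are now serialized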
@@ -51,19 +51,15 @@ def build_logger(logger_name, logger_filename):
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    # stdout_logger = logging.getLogger("stdout")
    # stdout_logger.setLevel(logging.INFO)
    # sl_1 = StreamToLogger(stdout_logger, logging.INFO)
    # sys.stdout = sl_1
    #
    # stderr_logger = logging.getLogger("stderr")
    # stderr_logger.setLevel(logging.ERROR)
    # sl = StreamToLogger(stderr_logger, logging.ERROR)
    # sys.stderr = sl

    # Add a file handler for all loggers
    if handler is None:
@@ -78,6 +74,12 @@ def build_logger(logger_name, logger_filename):
        if isinstance(item, logging.Logger):
            item.addHandler(handler)
    logging.basicConfig(level=logging.INFO)

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    return logger
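build_logger no longer redirects sys.stdout/sys.stderr into loggers, so plain print() output is visible again, while the shared file handler and logging.basicConfig(level=logging.INFO) still apply. Usage, as already done by the editor router above:

from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger

logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
logger.info("editor api ready")        # goes to the handlers configured by build_logger
print("plain prints are no longer swallowed by the logger")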