mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
style:fmt
This commit is contained in:
parent
0efaffc031
commit
71b9cd14a6
@ -20,8 +20,17 @@ CFG = Config()
|
||||
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||
static_message_img_path = os.path.join(os.getcwd(), "message/img")
|
||||
|
||||
|
||||
def zh_font_set():
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
@ -29,11 +38,14 @@ def zh_font_set():
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
|
||||
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
@command(
|
||||
"response_line_chart",
|
||||
"Line chart display, used to display comparative trend analysis data",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
)
|
||||
def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_line_chart:{speak},")
|
||||
|
||||
@ -44,7 +56,15 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
@ -52,31 +72,34 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
rc = {'font.sans-serif': can_use_fonts}
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=.5, s=.7)
|
||||
sns.set(context='notebook', style='ticks', rc=rc)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
|
||||
|
||||
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
||||
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
@command(
|
||||
"response_bar_chart",
|
||||
"Histogram, suitable for comparative analysis of multiple target values",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
)
|
||||
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_bar_chart:{speak},")
|
||||
columns = df.columns.tolist()
|
||||
@ -85,7 +108,15 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
@ -93,29 +124,32 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
rc = {'font.sans-serif': can_use_fonts}
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=.5, s=.7)
|
||||
sns.set(context='notebook', style='ticks', rc=rc)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
|
||||
|
||||
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
||||
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
@command(
|
||||
"response_pie_chart",
|
||||
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
)
|
||||
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_pie_chart:{speak},")
|
||||
columns = df.columns.tolist()
|
||||
@ -123,7 +157,15 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
raise ValueError("No Data!")
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
@ -131,23 +173,35 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
|
||||
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
|
||||
ax = df.plot(
|
||||
kind="pie",
|
||||
y=columns[1],
|
||||
ax=ax,
|
||||
labels=df[columns[0]].values,
|
||||
startangle=90,
|
||||
autopct="%1.1f%%",
|
||||
)
|
||||
# 手动设置 labels 的位置和大小
|
||||
ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10)
|
||||
ax.legend(
|
||||
loc="upper right",
|
||||
bbox_to_anchor=(0, 0, 1, 1),
|
||||
labels=df[columns[0]].values,
|
||||
fontsize=10,
|
||||
)
|
||||
|
||||
plt.axis('equal') # 使饼图为正圆形
|
||||
plt.axis("equal") # 使饼图为正圆形
|
||||
# plt.title(columns[0])
|
||||
|
||||
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
|
||||
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
|
||||
|
@ -11,8 +11,12 @@ CFG = Config()
|
||||
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
|
||||
|
||||
@command("response_table", "Table display, suitable for display with many display columns or non-numeric columns", '"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_table(speak: str, df: DataFrame) -> str:
|
||||
@command(
|
||||
"response_table",
|
||||
"Table display, suitable for display with many display columns or non-numeric columns",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
)
|
||||
def response_table(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_table:{speak}")
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
|
@ -10,9 +10,12 @@ CFG = Config()
|
||||
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
|
||||
|
||||
@command("response_data_text", "Text display, the default display method, suitable for single-line or simple content display",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_data_text(speak: str, df: DataFrame) -> str:
|
||||
@command(
|
||||
"response_data_text",
|
||||
"Text display, the default display method, suitable for single-line or simple content display",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
)
|
||||
def response_data_text(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_data_text:{speak}")
|
||||
data = df.values
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
def csv_colunm_foramt(val):
|
||||
if str(val).find("$") >= 0:
|
||||
return float(val.replace('$', '').replace(',', ''))
|
||||
return float(val.replace("$", "").replace(",", ""))
|
||||
if str(val).find("¥") >= 0:
|
||||
return float(val.replace('¥', '').replace(',', ''))
|
||||
return float(val.replace("¥", "").replace(",", ""))
|
||||
return val
|
||||
|
@ -262,7 +262,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def __write(self, write_sql):
|
||||
def __write(self, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
db_cache = self._engine.url.database
|
||||
result = self.session.execute(text(write_sql))
|
||||
@ -272,7 +272,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def __query(self,query, fetch: str = "all"):
|
||||
def __query(self, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
Args:
|
||||
@ -325,6 +325,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
result = list(result)
|
||||
return field_names, result
|
||||
return []
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("SQL:" + command)
|
||||
@ -333,12 +334,12 @@ class RDBMSDatabase(BaseConnect):
|
||||
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||
if ttype == sqlparse.tokens.DML:
|
||||
if sql_type == "SELECT":
|
||||
return self.__query( command, fetch)
|
||||
return self.__query(command, fetch)
|
||||
else:
|
||||
self.__write( command)
|
||||
self.__write(command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
return self.__query( select_sql)
|
||||
return self.__query(select_sql)
|
||||
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
@ -351,10 +352,10 @@ class RDBMSDatabase(BaseConnect):
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
if not result:
|
||||
return self.__query( f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.__query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return result
|
||||
else:
|
||||
return self.__query( f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.__query(f"SHOW COLUMNS FROM {table_name}")
|
||||
|
||||
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
@ -2,6 +2,6 @@ from pilot.configs.config import Config
|
||||
from pilot.connections.manages.connection_manager import ConnectManager
|
||||
|
||||
if __name__ == "__main__":
|
||||
mange= ConnectManager()
|
||||
mange = ConnectManager()
|
||||
types = mange.get_all_completed_types()
|
||||
print(str(types))
|
||||
print(str(types))
|
||||
|
@ -94,8 +94,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
|
||||
|
||||
def update(self, messages:List[OnceConversation]) -> None:
|
||||
def update(self, messages: List[OnceConversation]) -> None:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE chat_history set messages=? where conv_uid=?",
|
||||
@ -161,7 +160,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def get_messages(self) -> List[OnceConversation]:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
|
@ -110,7 +110,6 @@ async def db_connect_delete(db_name: str = None):
|
||||
|
||||
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
|
||||
async def db_support_types():
|
||||
|
||||
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
|
||||
db_type_infos = []
|
||||
for type in support_types:
|
||||
@ -130,7 +129,7 @@ async def dialogue_list(user_id: str = None):
|
||||
chat_mode = item.get("chat_mode")
|
||||
|
||||
messages = json.loads(item.get("messages"))
|
||||
last_round = max(messages, key=lambda x: x['chat_order'])
|
||||
last_round = max(messages, key=lambda x: x["chat_order"])
|
||||
if "param_value" in last_round:
|
||||
select_param = last_round["param_value"]
|
||||
else:
|
||||
@ -139,7 +138,7 @@ async def dialogue_list(user_id: str = None):
|
||||
conv_uid=conv_uid,
|
||||
user_input=summary,
|
||||
chat_mode=chat_mode,
|
||||
select_param=select_param
|
||||
select_param=select_param,
|
||||
)
|
||||
dialogues.append(conv_vo)
|
||||
|
||||
@ -213,7 +212,9 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
|
||||
),
|
||||
)
|
||||
## chat prepare
|
||||
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename)
|
||||
dialogue = ConversationVo(
|
||||
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
|
||||
)
|
||||
chat: BaseChat = get_chat_instance(dialogue)
|
||||
resp = chat.prepare()
|
||||
|
||||
@ -229,7 +230,8 @@ async def dialogue_delete(con_uid: str):
|
||||
history_mem.delete()
|
||||
return Result.succ(None)
|
||||
|
||||
def get_hist_messages(conv_uid:str):
|
||||
|
||||
def get_hist_messages(conv_uid: str):
|
||||
message_vos: List[MessageVo] = []
|
||||
history_mem = DuckdbHistoryMemory(conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
@ -264,7 +266,7 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
|
||||
chat_param = {
|
||||
"chat_session_id": dialogue.conv_uid,
|
||||
"user_input": dialogue.user_input,
|
||||
"select_param": dialogue.select_param
|
||||
"select_param": dialogue.select_param,
|
||||
}
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
|
||||
return chat
|
||||
|
@ -22,7 +22,7 @@ from pilot.openapi.editor_view_model import (
|
||||
ChartDetail,
|
||||
ChatChartEditContext,
|
||||
ChatSqlEditContext,
|
||||
DbTable
|
||||
DbTable,
|
||||
)
|
||||
|
||||
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
|
||||
@ -38,7 +38,9 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
|
||||
|
||||
|
||||
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
||||
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
|
||||
async def get_editor_tables(
|
||||
db_name: str, page_index: int, page_size: int, search_str: str = ""
|
||||
):
|
||||
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
tables = db_conn.get_table_names()
|
||||
@ -49,8 +51,15 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc
|
||||
fields = db_conn.get_fields(table)
|
||||
for field in fields:
|
||||
table_node.children.append(
|
||||
DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3],
|
||||
comment=field[-1]))
|
||||
DataNode(
|
||||
title=field[0],
|
||||
key=field[0],
|
||||
type=field[1],
|
||||
default_value=field[2],
|
||||
can_null=field[3],
|
||||
comment=field[-1],
|
||||
)
|
||||
)
|
||||
|
||||
return Result.succ(db_node)
|
||||
|
||||
@ -68,8 +77,11 @@ async def get_editor_sql_rounds(con_uid: str):
|
||||
if element["type"] == "human":
|
||||
round_name = element["data"]["content"]
|
||||
if once.get("param_value"):
|
||||
round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"],
|
||||
round_name=round_name)
|
||||
round: ChatDbRounds = ChatDbRounds(
|
||||
round=once["chat_order"],
|
||||
db_name=once["param_value"],
|
||||
round_name=round_name,
|
||||
)
|
||||
result.append(round)
|
||||
return Result.succ(result)
|
||||
|
||||
@ -84,8 +96,14 @@ async def get_editor_sql(con_uid: str, round: int):
|
||||
if int(once["chat_order"]) == round:
|
||||
for element in once["messages"]:
|
||||
if element["type"] == "ai":
|
||||
logger.info(f'history ai json resp:{element["data"]["content"]}')
|
||||
context = element["data"]["content"].replace("\\n", " ").replace("\n", " ")
|
||||
logger.info(
|
||||
f'history ai json resp:{element["data"]["content"]}'
|
||||
)
|
||||
context = (
|
||||
element["data"]["content"]
|
||||
.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
)
|
||||
return Result.succ(json.loads(context))
|
||||
return Result.faild(msg="not have sql!")
|
||||
|
||||
@ -93,8 +111,8 @@ async def get_editor_sql(con_uid: str, round: int):
|
||||
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
||||
async def editor_sql_run(run_param: dict = Body()):
|
||||
logger.info(f"editor_sql_run:{run_param}")
|
||||
db_name = run_param['db_name']
|
||||
sql = run_param['sql']
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
if not db_name and not sql:
|
||||
return Result.faild("SQL run param error!")
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
@ -104,18 +122,17 @@ async def editor_sql_run(run_param: dict = Body()):
|
||||
colunms, sql_result = conn.query_ex(sql)
|
||||
# 计算执行耗时
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result
|
||||
)
|
||||
sql_run_data: SqlRunData = SqlRunData(
|
||||
result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result,
|
||||
)
|
||||
return Result.succ(sql_run_data)
|
||||
except Exception as e:
|
||||
return Result.succ(SqlRunData(result_info=str(e),
|
||||
run_cost=0,
|
||||
colunms=[],
|
||||
values=[]
|
||||
))
|
||||
return Result.succ(
|
||||
SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/sql/editor/submit")
|
||||
@ -126,18 +143,24 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||
if history_messages:
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
|
||||
|
||||
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
|
||||
edit_round = list(
|
||||
filter(
|
||||
lambda x: x["chat_order"] == sql_edit_context.conv_round,
|
||||
history_messages,
|
||||
)
|
||||
)[0]
|
||||
if edit_round:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
db_resp = json.loads(element["data"]["content"])
|
||||
db_resp['thoughts'] = sql_edit_context.new_speak
|
||||
db_resp['sql'] = sql_edit_context.new_sql
|
||||
db_resp["thoughts"] = sql_edit_context.new_speak
|
||||
db_resp["sql"] = sql_edit_context.new_sql
|
||||
element["data"]["content"] = json.dumps(db_resp)
|
||||
if element["type"] == "view":
|
||||
data_loader = DbDataLoader()
|
||||
element["data"]["content"] = data_loader.get_table_view_by_conn(conn.run(sql_edit_context.new_sql),
|
||||
sql_edit_context.new_speak)
|
||||
element["data"]["content"] = data_loader.get_table_view_by_conn(
|
||||
conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
|
||||
)
|
||||
history_mem.update(history_messages)
|
||||
return Result.succ(None)
|
||||
return Result.faild(msg="Edit Faild!")
|
||||
@ -145,16 +168,21 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||
|
||||
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
|
||||
async def get_editor_chart_list(con_uid: str):
|
||||
logger.info(f"get_editor_sql_rounds:{con_uid}", )
|
||||
logger.info(
|
||||
f"get_editor_sql_rounds:{con_uid}",
|
||||
)
|
||||
history_mem = DuckdbHistoryMemory(con_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
last_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
db_name = last_round["param_value"]
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "ai":
|
||||
chart_list: ChartList = ChartList(round=last_round['chat_order'], db_name=db_name,
|
||||
charts=json.loads(element["data"]["content"]))
|
||||
chart_list: ChartList = ChartList(
|
||||
round=last_round["chat_order"],
|
||||
db_name=db_name,
|
||||
charts=json.loads(element["data"]["content"]),
|
||||
)
|
||||
return Result.succ(chart_list)
|
||||
return Result.faild(msg="Not have charts!")
|
||||
|
||||
@ -162,33 +190,40 @@ async def get_editor_chart_list(con_uid: str):
|
||||
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||
async def get_editor_chart_info(param: dict = Body()):
|
||||
logger.info(f"get_editor_chart_info:{param}")
|
||||
conv_uid = param['con_uid']
|
||||
chart_title = param['chart_title']
|
||||
conv_uid = param["con_uid"]
|
||||
chart_title = param["chart_title"]
|
||||
|
||||
history_mem = DuckdbHistoryMemory(conv_uid)
|
||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||
if history_messages:
|
||||
last_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
last_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
db_name = last_round["param_value"]
|
||||
if not db_name:
|
||||
logger.error("this dashboard dialogue version too old, can't support editor!")
|
||||
return Result.faild(msg="this dashboard dialogue version too old, can't support editor!")
|
||||
logger.error(
|
||||
"this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
return Result.faild(
|
||||
msg="this dashboard dialogue version too old, can't support editor!"
|
||||
)
|
||||
for element in last_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
view_data: dict = json.loads(element["data"]["content"])
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_title, charts))[0]
|
||||
find_chart = list(
|
||||
filter(lambda x: x["chart_name"] == chart_title, charts)
|
||||
)[0]
|
||||
|
||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'],
|
||||
chart_type=find_chart['chart_type'],
|
||||
chart_desc=find_chart['chart_desc'],
|
||||
chart_sql=find_chart['chart_sql'],
|
||||
db_name=db_name,
|
||||
chart_name=find_chart['chart_name'],
|
||||
chart_value=find_chart['values'],
|
||||
table_value=conn.run(find_chart['chart_sql'])
|
||||
)
|
||||
detail: ChartDetail = ChartDetail(
|
||||
chart_uid=find_chart["chart_uid"],
|
||||
chart_type=find_chart["chart_type"],
|
||||
chart_desc=find_chart["chart_desc"],
|
||||
chart_sql=find_chart["chart_sql"],
|
||||
db_name=db_name,
|
||||
chart_name=find_chart["chart_name"],
|
||||
chart_value=find_chart["values"],
|
||||
table_value=conn.run(find_chart["chart_sql"]),
|
||||
)
|
||||
|
||||
return Result.succ(detail)
|
||||
return Result.faild(msg="Can't Find Chart Detail Info!")
|
||||
@ -197,36 +232,39 @@ async def get_editor_chart_info(param: dict = Body()):
|
||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||
async def editor_chart_run(run_param: dict = Body()):
|
||||
logger.info(f"editor_chart_run:{run_param}")
|
||||
db_name = run_param['db_name']
|
||||
sql = run_param['sql']
|
||||
chart_type = run_param['chart_type']
|
||||
db_name = run_param["db_name"]
|
||||
sql = run_param["sql"]
|
||||
chart_type = run_param["chart_type"]
|
||||
if not db_name and not sql:
|
||||
return Result.faild("SQL run param error!")
|
||||
try:
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
colunms, sql_result = db_conn.query_ex(sql)
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(colunms, sql_result, sql)
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
|
||||
colunms, sql_result, sql
|
||||
)
|
||||
|
||||
start_time = time.time() * 1000
|
||||
# 计算执行耗时
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result
|
||||
)
|
||||
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values, chart_type = chart_type))
|
||||
sql_run_data: SqlRunData = SqlRunData(
|
||||
result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
colunms=colunms,
|
||||
values=sql_result,
|
||||
)
|
||||
return Result.succ(
|
||||
ChartRunData(
|
||||
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
sql_result = SqlRunData(result_info=str(e),
|
||||
run_cost=0,
|
||||
colunms=[],
|
||||
values=[]
|
||||
)
|
||||
return Result.succ(ChartRunData(sql_data = sql_result,
|
||||
chart_values=[],
|
||||
chart_type = chart_type
|
||||
))
|
||||
sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
|
||||
return Result.succ(
|
||||
ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
||||
@ -237,35 +275,53 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
|
||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
|
||||
|
||||
edit_round = max(history_messages, key=lambda x: x['chat_order'])
|
||||
edit_round = max(history_messages, key=lambda x: x["chat_order"])
|
||||
if edit_round:
|
||||
try:
|
||||
for element in edit_round["messages"]:
|
||||
if element["type"] == "view":
|
||||
view_data: dict = json.loads(element["data"]["content"]);
|
||||
view_data: dict = json.loads(element["data"]["content"])
|
||||
charts: List = view_data.get("charts")
|
||||
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_title, charts))[
|
||||
0]
|
||||
find_chart = list(
|
||||
filter(
|
||||
lambda x: x["chart_name"]
|
||||
== chart_edit_context.chart_title,
|
||||
charts,
|
||||
)
|
||||
)[0]
|
||||
if chart_edit_context.new_chart_type:
|
||||
find_chart['chart_type'] = chart_edit_context.new_chart_type
|
||||
find_chart["chart_type"] = chart_edit_context.new_chart_type
|
||||
if chart_edit_context.new_comment:
|
||||
find_chart['chart_desc'] = chart_edit_context.new_comment
|
||||
find_chart["chart_desc"] = chart_edit_context.new_comment
|
||||
|
||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn,
|
||||
chart_edit_context.new_sql)
|
||||
find_chart['chart_sql'] = chart_edit_context.new_sql
|
||||
find_chart['values'] = [value.dict() for value in chart_values]
|
||||
find_chart['column_name'] = field_names
|
||||
(
|
||||
field_names,
|
||||
chart_values,
|
||||
) = dashboard_data_loader.get_chart_values_by_conn(
|
||||
db_conn, chart_edit_context.new_sql
|
||||
)
|
||||
find_chart["chart_sql"] = chart_edit_context.new_sql
|
||||
find_chart["values"] = [value.dict() for value in chart_values]
|
||||
find_chart["column_name"] = field_names
|
||||
|
||||
element["data"]["content"] = json.dumps(view_data, ensure_ascii=False)
|
||||
element["data"]["content"] = json.dumps(
|
||||
view_data, ensure_ascii=False
|
||||
)
|
||||
if element["type"] == "ai":
|
||||
ai_resp: dict = json.loads(element["data"]["content"])
|
||||
edit_item = list(filter(lambda x: x['title'] == chart_edit_context.chart_title, ai_resp))[0]
|
||||
edit_item = list(
|
||||
filter(
|
||||
lambda x: x["title"] == chart_edit_context.chart_title,
|
||||
ai_resp,
|
||||
)
|
||||
)[0]
|
||||
|
||||
edit_item["sql"] = chart_edit_context.new_sql
|
||||
edit_item["showcase"] = chart_edit_context.new_chart_type
|
||||
edit_item["thoughts"] = chart_edit_context.new_comment
|
||||
element["data"]["content"] = json.dumps(ai_resp, ensure_ascii=False)
|
||||
element["data"]["content"] = json.dumps(
|
||||
ai_resp, ensure_ascii=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"edit chart exception!{str(e)}", e)
|
||||
return Result.faild(msg=f"Edit chart exception!{str(e)}")
|
||||
|
@ -9,7 +9,7 @@ class DataNode(BaseModel):
|
||||
|
||||
type: str = ""
|
||||
default_value: str = None
|
||||
can_null: str = 'YES'
|
||||
can_null: str = "YES"
|
||||
comment: str = None
|
||||
children: List = []
|
||||
|
||||
|
@ -10,11 +10,13 @@ class DbField(BaseModel):
|
||||
default_value: str = ""
|
||||
comment: str = ""
|
||||
|
||||
|
||||
class DbTable(BaseModel):
|
||||
table_name: str
|
||||
comment: str
|
||||
colunm: List[DbField]
|
||||
|
||||
|
||||
class ChatDbRounds(BaseModel):
|
||||
round: int
|
||||
db_name: str
|
||||
@ -61,4 +63,3 @@ class ChatSqlEditContext(BaseModel):
|
||||
|
||||
new_sql: str
|
||||
new_speak: str = ""
|
||||
|
||||
|
@ -123,11 +123,7 @@ class BaseOutputParser(ABC):
|
||||
ai_response = ai_response.replace("\*", "*")
|
||||
ai_response = ai_response.replace("\t", "")
|
||||
|
||||
ai_response = (
|
||||
ai_response.strip()
|
||||
.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
)
|
||||
ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
|
||||
print("un_stream ai response:", ai_response)
|
||||
return ai_response
|
||||
else:
|
||||
@ -209,9 +205,9 @@ class BaseOutputParser(ABC):
|
||||
cleaned_output = self.__extract_json(cleaned_output)
|
||||
cleaned_output = (
|
||||
cleaned_output.strip()
|
||||
.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
.replace("\\", " ")
|
||||
.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
.replace("\\", " ")
|
||||
)
|
||||
cleaned_output = self.__illegal_json_ends(cleaned_output)
|
||||
return cleaned_output
|
||||
|
@ -12,7 +12,6 @@ class Scene:
|
||||
is_inner: bool = False,
|
||||
show_disable=False,
|
||||
prepare_scene_code: str = None,
|
||||
|
||||
):
|
||||
self.code = code
|
||||
self.name = name
|
||||
@ -22,38 +21,39 @@ class Scene:
|
||||
self.show_disable = show_disable
|
||||
self.prepare_scene_code = prepare_scene_code
|
||||
|
||||
|
||||
class ChatScene(Enum):
|
||||
ChatWithDbExecute = Scene(
|
||||
code = "chat_with_db_execute",
|
||||
name = "Chat Data",
|
||||
describe = "Dialogue with your private data through natural language.",
|
||||
param_types = ["DB Select"],
|
||||
code="chat_with_db_execute",
|
||||
name="Chat Data",
|
||||
describe="Dialogue with your private data through natural language.",
|
||||
param_types=["DB Select"],
|
||||
)
|
||||
ExcelLearning = Scene(
|
||||
code = "excel_learning",
|
||||
name = "Excel Learning",
|
||||
describe = "Analyze and summarize your excel files.",
|
||||
is_inner = True,
|
||||
code="excel_learning",
|
||||
name="Excel Learning",
|
||||
describe="Analyze and summarize your excel files.",
|
||||
is_inner=True,
|
||||
)
|
||||
ChatExcel = Scene(
|
||||
code = "chat_excel",
|
||||
name = "Chat Excel",
|
||||
describe = "Dialogue with your excel, use natural language.",
|
||||
code="chat_excel",
|
||||
name="Chat Excel",
|
||||
describe="Dialogue with your excel, use natural language.",
|
||||
param_types=["File Select"],
|
||||
prepare_scene_code="excel_learning"
|
||||
prepare_scene_code="excel_learning",
|
||||
)
|
||||
|
||||
ChatWithDbQA = Scene(
|
||||
code = "chat_with_db_qa",
|
||||
name = "Chat DB",
|
||||
describe = "Have a Professional Conversation with Metadata.",
|
||||
param_types = ["DB Select"],
|
||||
code="chat_with_db_qa",
|
||||
name="Chat DB",
|
||||
describe="Have a Professional Conversation with Metadata.",
|
||||
param_types=["DB Select"],
|
||||
)
|
||||
ChatExecution = Scene(
|
||||
code = "chat_execution",
|
||||
name = "Use Plugin",
|
||||
describe = "Use tools through dialogue to accomplish your goals.",
|
||||
param_types = ["Plugin Select"],
|
||||
code="chat_execution",
|
||||
name="Use Plugin",
|
||||
describe="Use tools through dialogue to accomplish your goals.",
|
||||
param_types=["Plugin Select"],
|
||||
)
|
||||
|
||||
InnerChatDBSummary = Scene(
|
||||
@ -78,7 +78,7 @@ class ChatScene(Enum):
|
||||
|
||||
@staticmethod
|
||||
def of_mode(mode):
|
||||
return [x for x in ChatScene._value_ if x.code == mode][0]
|
||||
return [x for x in ChatScene._value_ if x.code == mode][0]
|
||||
|
||||
@staticmethod
|
||||
def is_valid_mode(mode):
|
||||
@ -100,4 +100,4 @@ class ChatScene(Enum):
|
||||
return self._value_.show_disable
|
||||
|
||||
def is_inner(self):
|
||||
return self._value_.is_inner
|
||||
return self._value_.is_inner
|
||||
|
@ -60,11 +60,7 @@ class BaseChat(ABC):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_mode,
|
||||
chat_session_id,
|
||||
current_user_input,
|
||||
select_param: Any = None
|
||||
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
|
||||
):
|
||||
self.chat_session_id = chat_session_id
|
||||
self.chat_mode = chat_mode
|
||||
@ -72,7 +68,6 @@ class BaseChat(ABC):
|
||||
self.llm_model = CFG.LLM_MODEL
|
||||
self.llm_echo = False
|
||||
|
||||
|
||||
### load prompt template
|
||||
# self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||
# self.chat_mode.value()
|
||||
@ -182,6 +177,7 @@ class BaseChat(ABC):
|
||||
return response
|
||||
else:
|
||||
from pilot.server.llmserver import worker
|
||||
|
||||
return worker.generate_stream_gate(payload)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
@ -235,7 +231,9 @@ class BaseChat(ABC):
|
||||
### llm speaker
|
||||
speak_to_user = self.get_llm_speak(prompt_define_response)
|
||||
|
||||
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
|
||||
view_message = self.prompt_template.output_parser.parse_view_response(
|
||||
speak_to_user, result
|
||||
)
|
||||
self.current_message.add_view_message(view_message)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
@ -253,10 +251,8 @@ class BaseChat(ABC):
|
||||
else:
|
||||
return self.nostream_call()
|
||||
|
||||
|
||||
def prepare(self):
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
def generate_llm_text(self) -> str:
|
||||
warnings.warn("This method is deprecated - please use `generate_llm_messages`.")
|
||||
@ -363,9 +359,7 @@ class BaseChat(ABC):
|
||||
)
|
||||
if len(self.history_message) > self.chat_retention_rounds:
|
||||
for first_message in self.history_message[0]["messages"]:
|
||||
if not first_message["type"] in [
|
||||
ModelMessageRoleType.VIEW
|
||||
]:
|
||||
if not first_message["type"] in [ModelMessageRoleType.VIEW]:
|
||||
message_type = first_message["type"]
|
||||
message_content = first_message["data"]["content"]
|
||||
history_text += (
|
||||
@ -394,7 +388,9 @@ class BaseChat(ABC):
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
history_messages.append(
|
||||
ModelMessage(role=message_type, content=message_content)
|
||||
ModelMessage(
|
||||
role=message_type, content=message_content
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -63,6 +63,7 @@ class AIMessage(BaseMessage):
|
||||
|
||||
class ViewMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the AI."""
|
||||
|
||||
example: bool = False
|
||||
|
||||
@property
|
||||
@ -73,6 +74,7 @@ class ViewMessage(BaseMessage):
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""Type of message that is a system message."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
|
@ -21,9 +21,15 @@ class ChatDashboard(BaseChat):
|
||||
report_name: str
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param:str = "", report_name:str="report"):
|
||||
def __init__(
|
||||
self,
|
||||
chat_session_id,
|
||||
user_input,
|
||||
select_param: str = "",
|
||||
report_name: str = "report",
|
||||
):
|
||||
""" """
|
||||
self.db_name=select_param
|
||||
self.db_name = select_param
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatDashboard,
|
||||
chat_session_id=chat_session_id,
|
||||
@ -80,7 +86,9 @@ class ChatDashboard(BaseChat):
|
||||
dashboard_data_loader = DashboardDataLoader()
|
||||
for chart_item in prompt_response:
|
||||
try:
|
||||
field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.database, chart_item.sql)
|
||||
field_names, values = dashboard_data_loader.get_chart_values_by_conn(
|
||||
self.database, chart_item.sql
|
||||
)
|
||||
chart_datas.append(
|
||||
ChartData(
|
||||
chart_uid=str(uuid.uuid1()),
|
||||
@ -101,5 +109,3 @@ class ChatDashboard(BaseChat):
|
||||
template_introduce=None,
|
||||
charts=chart_datas,
|
||||
)
|
||||
|
||||
|
||||
|
@ -11,15 +11,14 @@ logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
|
||||
|
||||
|
||||
class DashboardDataLoader:
|
||||
|
||||
def get_sql_value(self, db_conn, chart_sql: str):
|
||||
return db_conn.query_ex(chart_sql)
|
||||
return db_conn.query_ex(chart_sql)
|
||||
|
||||
def get_chart_values_by_conn(self, db_conn, chart_sql: str) :
|
||||
field_names, datas = db_conn.query_ex(chart_sql)
|
||||
return self.get_chart_values_by_data(field_names, datas, chart_sql)
|
||||
def get_chart_values_by_conn(self, db_conn, chart_sql: str):
|
||||
field_names, datas = db_conn.query_ex(chart_sql)
|
||||
return self.get_chart_values_by_data(field_names, datas, chart_sql)
|
||||
|
||||
def get_chart_values_by_data(self, field_names, datas, chart_sql: str) :
|
||||
def get_chart_values_by_data(self, field_names, datas, chart_sql: str):
|
||||
logger.info(f"get_chart_values_by_conn:{chart_sql}")
|
||||
try:
|
||||
values: List[ValueItem] = []
|
||||
@ -57,7 +56,7 @@ class DashboardDataLoader:
|
||||
logger.debug("Prepare Chart Data Faild!" + str(e))
|
||||
raise ValueError("Prepare Chart Data Faild!")
|
||||
|
||||
def get_chart_values_by_db(self, db_name: str, chart_sql: str) :
|
||||
def get_chart_values_by_db(self, db_name: str, chart_sql: str):
|
||||
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
|
||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||
return self.get_chart_values_by_conn(db_conn, chart_sql)
|
||||
|
@ -27,6 +27,7 @@ CFG = Config()
|
||||
class ChatExcel(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatExcel.value()
|
||||
chat_retention_rounds = 1
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
chat_mode = ChatScene.ChatExcel
|
||||
|
||||
@ -34,9 +35,11 @@ class ChatExcel(BaseChat):
|
||||
if has_path(select_param):
|
||||
self.excel_reader = ExcelReader(select_param)
|
||||
else:
|
||||
self.excel_reader = ExcelReader(os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
|
||||
))
|
||||
self.excel_reader = ExcelReader(
|
||||
os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
@ -70,12 +73,9 @@ class ChatExcel(BaseChat):
|
||||
]
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||
|
||||
|
||||
def generate_input_values(self):
|
||||
|
||||
|
||||
input_values = {
|
||||
"user_input": self.current_user_input,
|
||||
"user_input": self.current_user_input,
|
||||
"table_name": self.excel_reader.table_name,
|
||||
"disply_type": self._generate_numbered_list(),
|
||||
}
|
||||
@ -87,23 +87,21 @@ class ChatExcel(BaseChat):
|
||||
return None
|
||||
chat_param = {
|
||||
"chat_session_id": self.chat_session_id,
|
||||
"user_input": "[" + self.excel_reader.excel_file_name +"]" + " Analysis!",
|
||||
"user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analysis!",
|
||||
"parent_mode": self.chat_mode,
|
||||
"select_param":self.excel_reader.excel_file_name,
|
||||
"excel_reader": self.excel_reader
|
||||
"select_param": self.excel_reader.excel_file_name,
|
||||
"excel_reader": self.excel_reader,
|
||||
}
|
||||
learn_chat = ExcelLearning(**chat_param)
|
||||
result = learn_chat.nostream_call()
|
||||
return result
|
||||
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
print(f"do_action:{prompt_response}")
|
||||
|
||||
# colunms, datas = self.excel_reader.run(prompt_response.sql)
|
||||
param= {
|
||||
param = {
|
||||
"speak": prompt_response.thoughts,
|
||||
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql)
|
||||
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
|
||||
}
|
||||
return CFG.command_disply.call(prompt_response.display, **param)
|
||||
|
||||
|
@ -2,7 +2,9 @@ import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import ChatExcelOutputParser
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import (
|
||||
ChatExcelOutputParser,
|
||||
)
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
CFG = Config()
|
||||
@ -39,9 +41,9 @@ SQL中需要使用的表名是: {table_name}
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT_SIMPLE = {
|
||||
"sql": "analysis SQL",
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
"display": "display type name"
|
||||
"sql": "analysis SQL",
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
"display": "display type name",
|
||||
}
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
@ -67,9 +69,8 @@ prompt = PromptTemplate(
|
||||
output_parser=ChatExcelOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||
),
|
||||
need_historical_messages = True,
|
||||
need_historical_messages=True,
|
||||
# example_selector=sql_data_example,
|
||||
temperature=PROMPT_TEMPERATURE,
|
||||
)
|
||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||
|
||||
|
@ -23,7 +23,14 @@ CFG = Config()
|
||||
class ExcelLearning(BaseChat):
|
||||
chat_scene: str = ChatScene.ExcelLearning.value()
|
||||
|
||||
def __init__(self, chat_session_id, user_input, parent_mode: Any=None, select_param:str=None, excel_reader:Any=None):
|
||||
def __init__(
|
||||
self,
|
||||
chat_session_id,
|
||||
user_input,
|
||||
parent_mode: Any = None,
|
||||
select_param: str = None,
|
||||
excel_reader: Any = None,
|
||||
):
|
||||
chat_mode = ChatScene.ExcelLearning
|
||||
""" """
|
||||
self.excel_file_path = select_param
|
||||
@ -31,20 +38,19 @@ class ExcelLearning(BaseChat):
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input = user_input,
|
||||
current_user_input=user_input,
|
||||
select_param=select_param,
|
||||
)
|
||||
if parent_mode:
|
||||
self.current_message.chat_mode = parent_mode.value()
|
||||
|
||||
def generate_input_values(self):
|
||||
|
||||
colunms, datas = self.excel_reader.get_sample_data()
|
||||
datas.insert(0, colunms)
|
||||
|
||||
input_values = {
|
||||
"data_example": json.dumps(self.excel_reader.get_sample_data(), cls=DateTimeEncoder),
|
||||
"data_example": json.dumps(
|
||||
self.excel_reader.get_sample_data(), cls=DateTimeEncoder
|
||||
),
|
||||
}
|
||||
return input_values
|
||||
|
||||
|
||||
|
@ -37,23 +37,23 @@ class LearningExcelOutputParser(BaseOutputParser):
|
||||
plans = response[key]
|
||||
return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
|
||||
|
||||
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
### tool out data to table view
|
||||
html_title = f"### **数据简介**\n{data.desciption} "
|
||||
html_colunms = f"### **数据结构**\n"
|
||||
column_index = 0
|
||||
for item in data.clounms:
|
||||
column_index +=1
|
||||
column_index += 1
|
||||
keys = item.keys()
|
||||
for key in keys:
|
||||
html_colunms = html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
|
||||
html_colunms = (
|
||||
html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
|
||||
)
|
||||
|
||||
html_plans = f"### **分析计划**\n"
|
||||
index = 0
|
||||
for item in data.plans:
|
||||
index +=1
|
||||
index += 1
|
||||
html_plans = html_plans + f"{item} \n"
|
||||
html = f"""{html_title}\n{html_colunms}\n{html_plans}"""
|
||||
return html
|
||||
|
@ -2,7 +2,9 @@ import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import LearningExcelOutputParser
|
||||
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import (
|
||||
LearningExcelOutputParser,
|
||||
)
|
||||
from pilot.common.schema import SeparatorStyle
|
||||
|
||||
CFG = Config()
|
||||
@ -29,7 +31,7 @@ _DEFAULT_TEMPLATE_ZH = """
|
||||
{response}
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT_SIMPLE = {
|
||||
RESPONSE_FORMAT_SIMPLE = {
|
||||
"DataAnalysis": "数据内容分析总结",
|
||||
"ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}],
|
||||
"AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
|
||||
@ -63,5 +65,3 @@ prompt = PromptTemplate(
|
||||
temperature=PROMPT_TEMPERATURE,
|
||||
)
|
||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
from fsspec import filesystem
|
||||
import spatial
|
||||
import spatial
|
||||
|
||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
|
||||
@ -18,22 +18,35 @@ if __name__ == "__main__":
|
||||
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
|
||||
|
||||
# colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
|
||||
colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """)
|
||||
df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;")
|
||||
colunms, datas = excel_reader.run(
|
||||
""" SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """
|
||||
)
|
||||
df = excel_reader.get_df_by_sql_ex(
|
||||
"SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;"
|
||||
)
|
||||
columns = df.columns.tolist()
|
||||
plt.rcParams["font.family"] = ["sans-serif"]
|
||||
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
|
||||
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
|
||||
sns.set_style(rc={"font.sans-serif": "Microsoft Yahei"})
|
||||
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
|
||||
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
plt.subplots_adjust(top=0.9)
|
||||
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
|
||||
ax = df.plot(
|
||||
kind="pie",
|
||||
y=columns[1],
|
||||
ax=ax,
|
||||
labels=df[columns[0]].values,
|
||||
startangle=90,
|
||||
autopct="%1.1f%%",
|
||||
)
|
||||
# 手动设置 labels 的位置和大小
|
||||
ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10)
|
||||
plt.axis('equal') # 使饼图为正圆形
|
||||
ax.legend(
|
||||
loc="center left", bbox_to_anchor=(-1, 0.5, 0, 0), labels=None, fontsize=10
|
||||
)
|
||||
plt.axis("equal") # 使饼图为正圆形
|
||||
plt.show()
|
||||
#
|
||||
#
|
||||
|
@ -7,11 +7,13 @@ import numpy as np
|
||||
|
||||
from pilot.common.pd_utils import csv_colunm_foramt
|
||||
|
||||
def excel_colunm_format(old_name:str)->str:
|
||||
|
||||
def excel_colunm_format(old_name: str) -> str:
|
||||
new_column = old_name.strip()
|
||||
new_column = new_column.replace(" ", "_")
|
||||
return new_column
|
||||
|
||||
|
||||
def add_quotes(sql, column_names=[]):
|
||||
sql = sql.replace("`", "")
|
||||
parsed = sqlparse.parse(sql)
|
||||
@ -20,44 +22,51 @@ def add_quotes(sql, column_names=[]):
|
||||
deep_quotes(token, column_names)
|
||||
return str(parsed[0])
|
||||
|
||||
|
||||
def deep_quotes(token, column_names=[]):
|
||||
if hasattr(token, "tokens") :
|
||||
if hasattr(token, "tokens"):
|
||||
for token_child in token.tokens:
|
||||
deep_quotes(token_child, column_names)
|
||||
else:
|
||||
if token.ttype == sqlparse.tokens.Name:
|
||||
if len(column_names) >0:
|
||||
if len(column_names) > 0:
|
||||
if token.value in column_names:
|
||||
token.value = f'"{token.value.replace("`", "")}"'
|
||||
else:
|
||||
token.value = f'"{token.value.replace("`", "")}"'
|
||||
|
||||
|
||||
def is_chinese(string):
|
||||
# 使用正则表达式匹配中文字符
|
||||
pattern = re.compile(r'[一-龥]')
|
||||
pattern = re.compile(r"[一-龥]")
|
||||
match = re.search(pattern, string)
|
||||
return match is not None
|
||||
|
||||
|
||||
class ExcelReader:
|
||||
|
||||
def __init__(self, file_path):
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
file_name_without_extension = os.path.splitext(file_name)[0]
|
||||
|
||||
self.excel_file_name = file_name
|
||||
self.extension = os.path.splitext(file_name)[1]
|
||||
# read excel file
|
||||
if file_path.endswith('.xlsx') or file_path.endswith('.xls'):
|
||||
if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
|
||||
df_tmp = pd.read_excel(file_path)
|
||||
self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])})
|
||||
elif file_path.endswith('.csv'):
|
||||
self.df = pd.read_excel(
|
||||
file_path,
|
||||
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
|
||||
)
|
||||
elif file_path.endswith(".csv"):
|
||||
df_tmp = pd.read_csv(file_path)
|
||||
self.df = pd.read_csv(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])})
|
||||
self.df = pd.read_csv(
|
||||
file_path,
|
||||
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported file format.")
|
||||
|
||||
self.df.replace('', np.nan, inplace=True)
|
||||
self.df.replace("", np.nan, inplace=True)
|
||||
self.columns_map = {}
|
||||
for column_name in df_tmp.columns:
|
||||
self.columns_map.update({column_name: excel_colunm_format(column_name)})
|
||||
@ -66,18 +75,17 @@ class ExcelReader:
|
||||
except Exception as e:
|
||||
print("transfor column error!" + column_name)
|
||||
|
||||
self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_'))
|
||||
self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_"))
|
||||
|
||||
# connect DuckDB
|
||||
self.db = duckdb.connect(database=':memory:', read_only=False)
|
||||
|
||||
self.db = duckdb.connect(database=":memory:", read_only=False)
|
||||
|
||||
self.table_name = file_name_without_extension
|
||||
# write data in duckdb
|
||||
self.db.register(self.table_name, self.df)
|
||||
|
||||
def run(self, sql):
|
||||
if f'"{self.table_name}"' not in sql:
|
||||
if f'"{self.table_name}"' not in sql:
|
||||
sql = sql.replace(self.table_name, f'"{self.table_name}"')
|
||||
sql = add_quotes(sql, self.columns_map.values())
|
||||
print(f"excute sql:{sql}")
|
||||
@ -92,5 +100,4 @@ class ExcelReader:
|
||||
return pd.DataFrame(values, columns=colunms)
|
||||
|
||||
def get_sample_data(self):
|
||||
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;')
|
||||
|
||||
return self.run(f"SELECT * FROM {self.table_name} LIMIT 5;")
|
||||
|
@ -21,7 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param:str = ""):
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
chat_mode = ChatScene.ChatWithDbExecute
|
||||
self.db_name = select_param
|
||||
""" """
|
||||
@ -47,7 +47,9 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
client = DBSummaryClient()
|
||||
try:
|
||||
table_infos = client.get_db_summary(
|
||||
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||
dbname=self.db_name,
|
||||
query=self.current_user_input,
|
||||
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
)
|
||||
except Exception as e:
|
||||
print("db summary find error!" + str(e))
|
||||
@ -65,4 +67,4 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
print(f"do_action:{prompt_response}")
|
||||
return self.database.run( prompt_response.sql)
|
||||
return self.database.run(prompt_response.sql)
|
||||
|
@ -8,6 +8,7 @@ from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.chat_db.data_loader import DbDataLoader
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@ -49,6 +50,4 @@ class DbChatOutputParser(BaseOutputParser):
|
||||
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||
return view_text
|
||||
else:
|
||||
return data_loader.get_table_view_by_conn(data, speak)
|
||||
|
||||
|
||||
return data_loader.get_table_view_by_conn(data, speak)
|
||||
|
@ -1,8 +1,7 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DbDataLoader:
|
||||
|
||||
|
||||
def get_table_view_by_conn(self, data, speak):
|
||||
### tool out data to table view
|
||||
if len(data) <= 1:
|
||||
@ -12,4 +11,4 @@ class DbDataLoader:
|
||||
table_str = "".join(html_table.split())
|
||||
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||
return view_text
|
||||
return view_text
|
||||
|
@ -19,7 +19,7 @@ class ChatWithDbQA(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param:str = ""):
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
""" """
|
||||
self.db_name = select_param
|
||||
super().__init__(
|
||||
|
@ -20,12 +20,7 @@ class ChatWithPlugin(BaseChat):
|
||||
plugins_prompt_generator: PluginPromptGenerator
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_session_id,
|
||||
user_input,
|
||||
select_param: str = None
|
||||
):
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
self.plugin_selector = select_param
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatExecution,
|
||||
|
@ -30,7 +30,7 @@ class ChatKnowledge(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
""" """
|
||||
self.knowledge_space = select_param
|
||||
super().__init__(
|
||||
|
@ -30,7 +30,6 @@ class ChatNormal(BaseChat):
|
||||
input_values = {"input": self.current_user_input}
|
||||
return input_values
|
||||
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatNormal.value
|
||||
|
@ -117,12 +117,10 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
|
||||
"tokens": once.tokens if once.tokens else 0,
|
||||
"messages": messages_to_dict(once.messages),
|
||||
"param_type": once.param_type,
|
||||
"param_value": once.param_value
|
||||
"param_value": once.param_value,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
|
||||
return [_conversation_to_dic(m) for m in conversations]
|
||||
|
||||
|
@ -64,16 +64,12 @@ def server_init(args):
|
||||
|
||||
cfg.command_registry = command_registry
|
||||
|
||||
|
||||
command_disply_commands = [
|
||||
"pilot.commands.disply_type.show_chart_gen",
|
||||
"pilot.commands.disply_type.show_table_gen",
|
||||
"pilot.commands.disply_type.show_text_gen",
|
||||
]
|
||||
command_disply_registry = CommandRegistry()
|
||||
command_disply_registry = CommandRegistry()
|
||||
for command in command_disply_commands:
|
||||
command_disply_registry.import_commands(command)
|
||||
cfg.command_disply = command_disply_registry
|
||||
|
||||
|
||||
|
||||
|
@ -81,7 +81,9 @@ app.include_router(knowledge_router)
|
||||
# app.include_router(api_editor_route_v1)
|
||||
|
||||
os.makedirs(static_message_img_path, exist_ok=True)
|
||||
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2")
|
||||
app.mount(
|
||||
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
|
||||
)
|
||||
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
|
||||
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
|
||||
|
||||
|
@ -19,7 +19,6 @@ from pilot.utils import build_logger
|
||||
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
|
||||
|
||||
|
||||
|
||||
CFG = Config()
|
||||
chat_factory = ChatFactory()
|
||||
|
||||
@ -145,7 +144,9 @@ class DBSummaryClient:
|
||||
try:
|
||||
self.db_summary_embedding(item["db_name"], item["db_type"])
|
||||
except Exception as e:
|
||||
logger.warn(f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e)
|
||||
logger.warn(
|
||||
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e
|
||||
)
|
||||
|
||||
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
||||
profile_store_config = {
|
||||
|
@ -75,7 +75,6 @@ def build_logger(logger_name, logger_filename):
|
||||
item.addHandler(handler)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
Loading…
Reference in New Issue
Block a user