style:fmt

This commit is contained in:
aries_ckt 2023-08-29 19:57:33 +08:00
parent 0efaffc031
commit 71b9cd14a6
36 changed files with 413 additions and 280 deletions

View File

@ -20,8 +20,17 @@ CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
static_message_img_path = os.path.join(os.getcwd(), "message/img") static_message_img_path = os.path.join(os.getcwd(), "message/img")
def zh_font_set(): def zh_font_set():
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -29,11 +38,14 @@ def zh_font_set():
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_line_chart",
"Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_line_chart(speak: str, df: DataFrame) -> str: def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},") logger.info(f"response_line_chart:{speak},")
@ -44,7 +56,15 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
# set font # set font
# zh_font_set() # zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -52,31 +72,34 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {'font.sans-serif': can_use_fonts} rc = {"font.sans-serif": can_use_fonts}
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark") sns.set_style("dark")
sns.color_palette("hls", 10) sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7) sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context='notebook', style='ticks', rc=rc) sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax) sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
chart_name = "line_" + str(uuid.uuid1()) + ".png" chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_bar_chart(speak: str, df: DataFrame) -> str: def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},") logger.info(f"response_bar_chart:{speak},")
columns = df.columns.tolist() columns = df.columns.tolist()
@ -85,7 +108,15 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
# set font # set font
# zh_font_set() # zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -93,29 +124,32 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {'font.sans-serif': can_use_fonts} rc = {"font.sans-serif": can_use_fonts}
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark") sns.set_style("dark")
sns.color_palette("hls", 10) sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7) sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context='notebook', style='ticks', rc=rc) sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax) sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
chart_name = "pie_" + str(uuid.uuid1()) + ".png" chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_pie_chart(speak: str, df: DataFrame) -> str: def response_pie_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_pie_chart:{speak},") logger.info(f"response_pie_chart:{speak},")
columns = df.columns.tolist() columns = df.columns.tolist()
@ -123,7 +157,15 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
raise ValueError("No Data") raise ValueError("No Data")
# set font # set font
# zh_font_set() # zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -131,23 +173,35 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
# 手动设置 labels 的位置和大小 # 手动设置 labels 的位置和大小
ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10) ax.legend(
loc="upper right",
bbox_to_anchor=(0, 0, 1, 1),
labels=df[columns[0]].values,
fontsize=10,
)
plt.axis('equal') # 使饼图为正圆形 plt.axis("equal") # 使饼图为正圆形
# plt.title(columns[0]) # plt.title(columns[0])
chart_name = "pie_" + str(uuid.uuid1()) + ".png" chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

View File

@ -11,8 +11,12 @@ CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command("response_table", "Table display, suitable for display with many display columns or non-numeric columns", '"speak": "<speak>", "df":"<data frame>"') @command(
def response_table(speak: str, df: DataFrame) -> str: "response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_table(speak: str, df: DataFrame) -> str:
logger.info(f"response_table:{speak}") logger.info(f"response_table:{speak}")
html_table = df.to_html(index=False, escape=False, sparsify=False) html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split()) table_str = "".join(html_table.split())

View File

@ -10,9 +10,12 @@ CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command("response_data_text", "Text display, the default display method, suitable for single-line or simple content display", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_data_text",
def response_data_text(speak: str, df: DataFrame) -> str: "Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_data_text(speak: str, df: DataFrame) -> str:
logger.info(f"response_data_text:{speak}") logger.info(f"response_data_text:{speak}")
data = df.values data = df.values

View File

@ -1,6 +1,6 @@
def csv_colunm_foramt(val): def csv_colunm_foramt(val):
if str(val).find("$") >= 0: if str(val).find("$") >= 0:
return float(val.replace('$', '').replace(',', '')) return float(val.replace("$", "").replace(",", ""))
if str(val).find("¥") >= 0: if str(val).find("¥") >= 0:
return float(val.replace('¥', '').replace(',', '')) return float(val.replace("¥", "").replace(",", ""))
return val return val

View File

@ -262,7 +262,7 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message""" """Format the error message"""
return f"Error: {e}" return f"Error: {e}"
def __write(self, write_sql): def __write(self, write_sql):
print(f"Write[{write_sql}]") print(f"Write[{write_sql}]")
db_cache = self._engine.url.database db_cache = self._engine.url.database
result = self.session.execute(text(write_sql)) result = self.session.execute(text(write_sql))
@ -272,7 +272,7 @@ class RDBMSDatabase(BaseConnect):
print(f"SQL[{write_sql}], result:{result.rowcount}") print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount return result.rowcount
def __query(self,query, fetch: str = "all"): def __query(self, query, fetch: str = "all"):
""" """
only for query only for query
Args: Args:
@ -325,6 +325,7 @@ class RDBMSDatabase(BaseConnect):
result = list(result) result = list(result)
return field_names, result return field_names, result
return [] return []
def run(self, command: str, fetch: str = "all") -> List: def run(self, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.""" """Execute a SQL command and return a string representing the results."""
print("SQL:" + command) print("SQL:" + command)
@ -333,12 +334,12 @@ class RDBMSDatabase(BaseConnect):
parsed, ttype, sql_type, table_name = self.__sql_parse(command) parsed, ttype, sql_type, table_name = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML: if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT": if sql_type == "SELECT":
return self.__query( command, fetch) return self.__query(command, fetch)
else: else:
self.__write( command) self.__write(command)
select_sql = self.convert_sql_write_to_select(command) select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}") print(f"write result query:{select_sql}")
return self.__query( select_sql) return self.__query(select_sql)
else: else:
print(f"DDL execution determines whether to enable through configuration ") print(f"DDL execution determines whether to enable through configuration ")
@ -351,10 +352,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names) result.insert(0, field_names)
print("DDL Result:" + str(result)) print("DDL Result:" + str(result))
if not result: if not result:
return self.__query( f"SHOW COLUMNS FROM {table_name}") return self.__query(f"SHOW COLUMNS FROM {table_name}")
return result return result
else: else:
return self.__query( f"SHOW COLUMNS FROM {table_name}") return self.__query(f"SHOW COLUMNS FROM {table_name}")
def run_no_throw(self, session, command: str, fetch: str = "all") -> List: def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results. """Execute a SQL command and return a string representing the results.

View File

@ -2,6 +2,6 @@ from pilot.configs.config import Config
from pilot.connections.manages.connection_manager import ConnectManager from pilot.connections.manages.connection_manager import ConnectManager
if __name__ == "__main__": if __name__ == "__main__":
mange= ConnectManager() mange = ConnectManager()
types = mange.get_all_completed_types() types = mange.get_all_completed_types()
print(str(types)) print(str(types))

View File

@ -94,8 +94,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.commit() cursor.commit()
self.connect.commit() self.connect.commit()
def update(self, messages: List[OnceConversation]) -> None:
def update(self, messages:List[OnceConversation]) -> None:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute( cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?", "UPDATE chat_history set messages=? where conv_uid=?",
@ -161,7 +160,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return {} return {}
def get_messages(self) -> List[OnceConversation]: def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute( cursor.execute(

View File

@ -110,7 +110,6 @@ async def db_connect_delete(db_name: str = None):
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) @router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types(): async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types() support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
db_type_infos = [] db_type_infos = []
for type in support_types: for type in support_types:
@ -130,7 +129,7 @@ async def dialogue_list(user_id: str = None):
chat_mode = item.get("chat_mode") chat_mode = item.get("chat_mode")
messages = json.loads(item.get("messages")) messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x['chat_order']) last_round = max(messages, key=lambda x: x["chat_order"])
if "param_value" in last_round: if "param_value" in last_round:
select_param = last_round["param_value"] select_param = last_round["param_value"]
else: else:
@ -139,7 +138,7 @@ async def dialogue_list(user_id: str = None):
conv_uid=conv_uid, conv_uid=conv_uid,
user_input=summary, user_input=summary,
chat_mode=chat_mode, chat_mode=chat_mode,
select_param=select_param select_param=select_param,
) )
dialogues.append(conv_vo) dialogues.append(conv_vo)
@ -213,7 +212,9 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
), ),
) )
## chat prepare ## chat prepare
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename) dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
)
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = get_chat_instance(dialogue)
resp = chat.prepare() resp = chat.prepare()
@ -229,7 +230,8 @@ async def dialogue_delete(con_uid: str):
history_mem.delete() history_mem.delete()
return Result.succ(None) return Result.succ(None)
def get_hist_messages(conv_uid:str):
def get_hist_messages(conv_uid: str):
message_vos: List[MessageVo] = [] message_vos: List[MessageVo] = []
history_mem = DuckdbHistoryMemory(conv_uid) history_mem = DuckdbHistoryMemory(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
@ -264,7 +266,7 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
chat_param = { chat_param = {
"chat_session_id": dialogue.conv_uid, "chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input, "user_input": dialogue.user_input,
"select_param": dialogue.select_param "select_param": dialogue.select_param,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
return chat return chat

View File

@ -22,7 +22,7 @@ from pilot.openapi.editor_view_model import (
ChartDetail, ChartDetail,
ChatChartEditContext, ChatChartEditContext,
ChatSqlEditContext, ChatSqlEditContext,
DbTable DbTable,
) )
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
@ -38,7 +38,9 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
@router.get("/v1/editor/db/tables", response_model=Result[DbTable]) @router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""): async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}") logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
tables = db_conn.get_table_names() tables = db_conn.get_table_names()
@ -49,8 +51,15 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc
fields = db_conn.get_fields(table) fields = db_conn.get_fields(table)
for field in fields: for field in fields:
table_node.children.append( table_node.children.append(
DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3], DataNode(
comment=field[-1])) title=field[0],
key=field[0],
type=field[1],
default_value=field[2],
can_null=field[3],
comment=field[-1],
)
)
return Result.succ(db_node) return Result.succ(db_node)
@ -68,8 +77,11 @@ async def get_editor_sql_rounds(con_uid: str):
if element["type"] == "human": if element["type"] == "human":
round_name = element["data"]["content"] round_name = element["data"]["content"]
if once.get("param_value"): if once.get("param_value"):
round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"], round: ChatDbRounds = ChatDbRounds(
round_name=round_name) round=once["chat_order"],
db_name=once["param_value"],
round_name=round_name,
)
result.append(round) result.append(round)
return Result.succ(result) return Result.succ(result)
@ -84,8 +96,14 @@ async def get_editor_sql(con_uid: str, round: int):
if int(once["chat_order"]) == round: if int(once["chat_order"]) == round:
for element in once["messages"]: for element in once["messages"]:
if element["type"] == "ai": if element["type"] == "ai":
logger.info(f'history ai json resp:{element["data"]["content"]}') logger.info(
context = element["data"]["content"].replace("\\n", " ").replace("\n", " ") f'history ai json resp:{element["data"]["content"]}'
)
context = (
element["data"]["content"]
.replace("\\n", " ")
.replace("\n", " ")
)
return Result.succ(json.loads(context)) return Result.succ(json.loads(context))
return Result.faild(msg="not have sql!") return Result.faild(msg="not have sql!")
@ -93,8 +111,8 @@ async def get_editor_sql(con_uid: str, round: int):
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData]) @router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()): async def editor_sql_run(run_param: dict = Body()):
logger.info(f"editor_sql_run:{run_param}") logger.info(f"editor_sql_run:{run_param}")
db_name = run_param['db_name'] db_name = run_param["db_name"]
sql = run_param['sql'] sql = run_param["sql"]
if not db_name and not sql: if not db_name and not sql:
return Result.faild("SQL run param error") return Result.faild("SQL run param error")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
@ -104,18 +122,17 @@ async def editor_sql_run(run_param: dict = Body()):
colunms, sql_result = conn.query_ex(sql) colunms, sql_result = conn.query_ex(sql)
# 计算执行耗时 # 计算执行耗时
end_time = time.time() * 1000 end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(result_info="", sql_run_data: SqlRunData = SqlRunData(
run_cost=(end_time - start_time) / 1000, result_info="",
colunms=colunms, run_cost=(end_time - start_time) / 1000,
values=sql_result colunms=colunms,
) values=sql_result,
)
return Result.succ(sql_run_data) return Result.succ(sql_run_data)
except Exception as e: except Exception as e:
return Result.succ(SqlRunData(result_info=str(e), return Result.succ(
run_cost=0, SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
colunms=[], )
values=[]
))
@router.post("/v1/sql/editor/submit") @router.post("/v1/sql/editor/submit")
@ -126,18 +143,24 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
if history_messages: if history_messages:
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0] edit_round = list(
filter(
lambda x: x["chat_order"] == sql_edit_context.conv_round,
history_messages,
)
)[0]
if edit_round: if edit_round:
for element in edit_round["messages"]: for element in edit_round["messages"]:
if element["type"] == "ai": if element["type"] == "ai":
db_resp = json.loads(element["data"]["content"]) db_resp = json.loads(element["data"]["content"])
db_resp['thoughts'] = sql_edit_context.new_speak db_resp["thoughts"] = sql_edit_context.new_speak
db_resp['sql'] = sql_edit_context.new_sql db_resp["sql"] = sql_edit_context.new_sql
element["data"]["content"] = json.dumps(db_resp) element["data"]["content"] = json.dumps(db_resp)
if element["type"] == "view": if element["type"] == "view":
data_loader = DbDataLoader() data_loader = DbDataLoader()
element["data"]["content"] = data_loader.get_table_view_by_conn(conn.run(sql_edit_context.new_sql), element["data"]["content"] = data_loader.get_table_view_by_conn(
sql_edit_context.new_speak) conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
)
history_mem.update(history_messages) history_mem.update(history_messages)
return Result.succ(None) return Result.succ(None)
return Result.faild(msg="Edit Faild!") return Result.faild(msg="Edit Faild!")
@ -145,16 +168,21 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
@router.get("/v1/editor/chart/list", response_model=Result[ChartList]) @router.get("/v1/editor/chart/list", response_model=Result[ChartList])
async def get_editor_chart_list(con_uid: str): async def get_editor_chart_list(con_uid: str):
logger.info(f"get_editor_sql_rounds:{con_uid}", ) logger.info(
f"get_editor_sql_rounds:{con_uid}",
)
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
last_round = max(history_messages, key=lambda x: x['chat_order']) last_round = max(history_messages, key=lambda x: x["chat_order"])
db_name = last_round["param_value"] db_name = last_round["param_value"]
for element in last_round["messages"]: for element in last_round["messages"]:
if element["type"] == "ai": if element["type"] == "ai":
chart_list: ChartList = ChartList(round=last_round['chat_order'], db_name=db_name, chart_list: ChartList = ChartList(
charts=json.loads(element["data"]["content"])) round=last_round["chat_order"],
db_name=db_name,
charts=json.loads(element["data"]["content"]),
)
return Result.succ(chart_list) return Result.succ(chart_list)
return Result.faild(msg="Not have charts!") return Result.faild(msg="Not have charts!")
@ -162,33 +190,40 @@ async def get_editor_chart_list(con_uid: str):
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail]) @router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
async def get_editor_chart_info(param: dict = Body()): async def get_editor_chart_info(param: dict = Body()):
logger.info(f"get_editor_chart_info:{param}") logger.info(f"get_editor_chart_info:{param}")
conv_uid = param['con_uid'] conv_uid = param["con_uid"]
chart_title = param['chart_title'] chart_title = param["chart_title"]
history_mem = DuckdbHistoryMemory(conv_uid) history_mem = DuckdbHistoryMemory(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
last_round = max(history_messages, key=lambda x: x['chat_order']) last_round = max(history_messages, key=lambda x: x["chat_order"])
db_name = last_round["param_value"] db_name = last_round["param_value"]
if not db_name: if not db_name:
logger.error("this dashboard dialogue version too old, can't support editor!") logger.error(
return Result.faild(msg="this dashboard dialogue version too old, can't support editor!") "this dashboard dialogue version too old, can't support editor!"
)
return Result.faild(
msg="this dashboard dialogue version too old, can't support editor!"
)
for element in last_round["messages"]: for element in last_round["messages"]:
if element["type"] == "view": if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]); view_data: dict = json.loads(element["data"]["content"])
charts: List = view_data.get("charts") charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_title, charts))[0] find_chart = list(
filter(lambda x: x["chart_name"] == chart_title, charts)
)[0]
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'], detail: ChartDetail = ChartDetail(
chart_type=find_chart['chart_type'], chart_uid=find_chart["chart_uid"],
chart_desc=find_chart['chart_desc'], chart_type=find_chart["chart_type"],
chart_sql=find_chart['chart_sql'], chart_desc=find_chart["chart_desc"],
db_name=db_name, chart_sql=find_chart["chart_sql"],
chart_name=find_chart['chart_name'], db_name=db_name,
chart_value=find_chart['values'], chart_name=find_chart["chart_name"],
table_value=conn.run(find_chart['chart_sql']) chart_value=find_chart["values"],
) table_value=conn.run(find_chart["chart_sql"]),
)
return Result.succ(detail) return Result.succ(detail)
return Result.faild(msg="Can't Find Chart Detail Info!") return Result.faild(msg="Can't Find Chart Detail Info!")
@ -197,36 +232,39 @@ async def get_editor_chart_info(param: dict = Body()):
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(run_param: dict = Body()): async def editor_chart_run(run_param: dict = Body()):
logger.info(f"editor_chart_run:{run_param}") logger.info(f"editor_chart_run:{run_param}")
db_name = run_param['db_name'] db_name = run_param["db_name"]
sql = run_param['sql'] sql = run_param["sql"]
chart_type = run_param['chart_type'] chart_type = run_param["chart_type"]
if not db_name and not sql: if not db_name and not sql:
return Result.faild("SQL run param error") return Result.faild("SQL run param error")
try: try:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
colunms, sql_result = db_conn.query_ex(sql) colunms, sql_result = db_conn.query_ex(sql)
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(colunms, sql_result, sql) field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms, sql_result, sql
)
start_time = time.time() * 1000 start_time = time.time() * 1000
# 计算执行耗时 # 计算执行耗时
end_time = time.time() * 1000 end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(result_info="", sql_run_data: SqlRunData = SqlRunData(
run_cost=(end_time - start_time) / 1000, result_info="",
colunms=colunms, run_cost=(end_time - start_time) / 1000,
values=sql_result colunms=colunms,
) values=sql_result,
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values, chart_type = chart_type)) )
return Result.succ(
ChartRunData(
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
)
)
except Exception as e: except Exception as e:
sql_result = SqlRunData(result_info=str(e), sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
run_cost=0, return Result.succ(
colunms=[], ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
values=[] )
)
return Result.succ(ChartRunData(sql_data = sql_result,
chart_values=[],
chart_type = chart_type
))
@router.post("/v1/chart/editor/submit", response_model=Result[bool]) @router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()): async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
@ -237,35 +275,53 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
edit_round = max(history_messages, key=lambda x: x['chat_order']) edit_round = max(history_messages, key=lambda x: x["chat_order"])
if edit_round: if edit_round:
try: try:
for element in edit_round["messages"]: for element in edit_round["messages"]:
if element["type"] == "view": if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]); view_data: dict = json.loads(element["data"]["content"])
charts: List = view_data.get("charts") charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_title, charts))[ find_chart = list(
0] filter(
lambda x: x["chart_name"]
== chart_edit_context.chart_title,
charts,
)
)[0]
if chart_edit_context.new_chart_type: if chart_edit_context.new_chart_type:
find_chart['chart_type'] = chart_edit_context.new_chart_type find_chart["chart_type"] = chart_edit_context.new_chart_type
if chart_edit_context.new_comment: if chart_edit_context.new_comment:
find_chart['chart_desc'] = chart_edit_context.new_comment find_chart["chart_desc"] = chart_edit_context.new_comment
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, (
chart_edit_context.new_sql) field_names,
find_chart['chart_sql'] = chart_edit_context.new_sql chart_values,
find_chart['values'] = [value.dict() for value in chart_values] ) = dashboard_data_loader.get_chart_values_by_conn(
find_chart['column_name'] = field_names db_conn, chart_edit_context.new_sql
)
find_chart["chart_sql"] = chart_edit_context.new_sql
find_chart["values"] = [value.dict() for value in chart_values]
find_chart["column_name"] = field_names
element["data"]["content"] = json.dumps(view_data, ensure_ascii=False) element["data"]["content"] = json.dumps(
view_data, ensure_ascii=False
)
if element["type"] == "ai": if element["type"] == "ai":
ai_resp: dict = json.loads(element["data"]["content"]) ai_resp: dict = json.loads(element["data"]["content"])
edit_item = list(filter(lambda x: x['title'] == chart_edit_context.chart_title, ai_resp))[0] edit_item = list(
filter(
lambda x: x["title"] == chart_edit_context.chart_title,
ai_resp,
)
)[0]
edit_item["sql"] = chart_edit_context.new_sql edit_item["sql"] = chart_edit_context.new_sql
edit_item["showcase"] = chart_edit_context.new_chart_type edit_item["showcase"] = chart_edit_context.new_chart_type
edit_item["thoughts"] = chart_edit_context.new_comment edit_item["thoughts"] = chart_edit_context.new_comment
element["data"]["content"] = json.dumps(ai_resp, ensure_ascii=False) element["data"]["content"] = json.dumps(
ai_resp, ensure_ascii=False
)
except Exception as e: except Exception as e:
logger.error(f"edit chart exception!{str(e)}", e) logger.error(f"edit chart exception!{str(e)}", e)
return Result.faild(msg=f"Edit chart exception!{str(e)}") return Result.faild(msg=f"Edit chart exception!{str(e)}")

View File

@ -9,7 +9,7 @@ class DataNode(BaseModel):
type: str = "" type: str = ""
default_value: str = None default_value: str = None
can_null: str = 'YES' can_null: str = "YES"
comment: str = None comment: str = None
children: List = [] children: List = []

View File

@ -10,11 +10,13 @@ class DbField(BaseModel):
default_value: str = "" default_value: str = ""
comment: str = "" comment: str = ""
class DbTable(BaseModel): class DbTable(BaseModel):
table_name: str table_name: str
comment: str comment: str
colunm: List[DbField] colunm: List[DbField]
class ChatDbRounds(BaseModel): class ChatDbRounds(BaseModel):
round: int round: int
db_name: str db_name: str
@ -61,4 +63,3 @@ class ChatSqlEditContext(BaseModel):
new_sql: str new_sql: str
new_speak: str = "" new_speak: str = ""

View File

@ -123,11 +123,7 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\*", "*") ai_response = ai_response.replace("\*", "*")
ai_response = ai_response.replace("\t", "") ai_response = ai_response.replace("\t", "")
ai_response = ( ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
ai_response.strip()
.replace("\\n", " ")
.replace("\n", " ")
)
print("un_stream ai response:", ai_response) print("un_stream ai response:", ai_response)
return ai_response return ai_response
else: else:
@ -209,9 +205,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output) cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = ( cleaned_output = (
cleaned_output.strip() cleaned_output.strip()
.replace("\\n", " ") .replace("\\n", " ")
.replace("\n", " ") .replace("\n", " ")
.replace("\\", " ") .replace("\\", " ")
) )
cleaned_output = self.__illegal_json_ends(cleaned_output) cleaned_output = self.__illegal_json_ends(cleaned_output)
return cleaned_output return cleaned_output

View File

@ -12,7 +12,6 @@ class Scene:
is_inner: bool = False, is_inner: bool = False,
show_disable=False, show_disable=False,
prepare_scene_code: str = None, prepare_scene_code: str = None,
): ):
self.code = code self.code = code
self.name = name self.name = name
@ -22,38 +21,39 @@ class Scene:
self.show_disable = show_disable self.show_disable = show_disable
self.prepare_scene_code = prepare_scene_code self.prepare_scene_code = prepare_scene_code
class ChatScene(Enum): class ChatScene(Enum):
ChatWithDbExecute = Scene( ChatWithDbExecute = Scene(
code = "chat_with_db_execute", code="chat_with_db_execute",
name = "Chat Data", name="Chat Data",
describe = "Dialogue with your private data through natural language.", describe="Dialogue with your private data through natural language.",
param_types = ["DB Select"], param_types=["DB Select"],
) )
ExcelLearning = Scene( ExcelLearning = Scene(
code = "excel_learning", code="excel_learning",
name = "Excel Learning", name="Excel Learning",
describe = "Analyze and summarize your excel files.", describe="Analyze and summarize your excel files.",
is_inner = True, is_inner=True,
) )
ChatExcel = Scene( ChatExcel = Scene(
code = "chat_excel", code="chat_excel",
name = "Chat Excel", name="Chat Excel",
describe = "Dialogue with your excel, use natural language.", describe="Dialogue with your excel, use natural language.",
param_types=["File Select"], param_types=["File Select"],
prepare_scene_code="excel_learning" prepare_scene_code="excel_learning",
) )
ChatWithDbQA = Scene( ChatWithDbQA = Scene(
code = "chat_with_db_qa", code="chat_with_db_qa",
name = "Chat DB", name="Chat DB",
describe = "Have a Professional Conversation with Metadata.", describe="Have a Professional Conversation with Metadata.",
param_types = ["DB Select"], param_types=["DB Select"],
) )
ChatExecution = Scene( ChatExecution = Scene(
code = "chat_execution", code="chat_execution",
name = "Use Plugin", name="Use Plugin",
describe = "Use tools through dialogue to accomplish your goals.", describe="Use tools through dialogue to accomplish your goals.",
param_types = ["Plugin Select"], param_types=["Plugin Select"],
) )
InnerChatDBSummary = Scene( InnerChatDBSummary = Scene(
@ -78,7 +78,7 @@ class ChatScene(Enum):
@staticmethod @staticmethod
def of_mode(mode): def of_mode(mode):
return [x for x in ChatScene._value_ if x.code == mode][0] return [x for x in ChatScene._value_ if x.code == mode][0]
@staticmethod @staticmethod
def is_valid_mode(mode): def is_valid_mode(mode):
@ -100,4 +100,4 @@ class ChatScene(Enum):
return self._value_.show_disable return self._value_.show_disable
def is_inner(self): def is_inner(self):
return self._value_.is_inner return self._value_.is_inner

View File

@ -60,11 +60,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(
self, self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
chat_mode,
chat_session_id,
current_user_input,
select_param: Any = None
): ):
self.chat_session_id = chat_session_id self.chat_session_id = chat_session_id
self.chat_mode = chat_mode self.chat_mode = chat_mode
@ -72,7 +68,6 @@ class BaseChat(ABC):
self.llm_model = CFG.LLM_MODEL self.llm_model = CFG.LLM_MODEL
self.llm_echo = False self.llm_echo = False
### load prompt template ### load prompt template
# self.prompt_template: PromptTemplate = CFG.prompt_templates[ # self.prompt_template: PromptTemplate = CFG.prompt_templates[
# self.chat_mode.value() # self.chat_mode.value()
@ -182,6 +177,7 @@ class BaseChat(ABC):
return response return response
else: else:
from pilot.server.llmserver import worker from pilot.server.llmserver import worker
return worker.generate_stream_gate(payload) return worker.generate_stream_gate(payload)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
@ -235,7 +231,9 @@ class BaseChat(ABC):
### llm speaker ### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response) speak_to_user = self.get_llm_speak(prompt_define_response)
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result
)
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
@ -253,10 +251,8 @@ class BaseChat(ABC):
else: else:
return self.nostream_call() return self.nostream_call()
def prepare(self): def prepare(self):
pass pass
def generate_llm_text(self) -> str: def generate_llm_text(self) -> str:
warnings.warn("This method is deprecated - please use `generate_llm_messages`.") warnings.warn("This method is deprecated - please use `generate_llm_messages`.")
@ -363,9 +359,7 @@ class BaseChat(ABC):
) )
if len(self.history_message) > self.chat_retention_rounds: if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]["messages"]: for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [ if not first_message["type"] in [ModelMessageRoleType.VIEW]:
ModelMessageRoleType.VIEW
]:
message_type = first_message["type"] message_type = first_message["type"]
message_content = first_message["data"]["content"] message_content = first_message["data"]["content"]
history_text += ( history_text += (
@ -394,7 +388,9 @@ class BaseChat(ABC):
+ self.prompt_template.sep + self.prompt_template.sep
) )
history_messages.append( history_messages.append(
ModelMessage(role=message_type, content=message_content) ModelMessage(
role=message_type, content=message_content
)
) )
else: else:

View File

@ -63,6 +63,7 @@ class AIMessage(BaseMessage):
class ViewMessage(BaseMessage): class ViewMessage(BaseMessage):
"""Type of message that is spoken by the AI.""" """Type of message that is spoken by the AI."""
example: bool = False example: bool = False
@property @property
@ -73,6 +74,7 @@ class ViewMessage(BaseMessage):
class SystemMessage(BaseMessage): class SystemMessage(BaseMessage):
"""Type of message that is a system message.""" """Type of message that is a system message."""
@property @property
def type(self) -> str: def type(self) -> str:
"""Type of the message, used for serialization.""" """Type of the message, used for serialization."""

View File

@ -21,9 +21,15 @@ class ChatDashboard(BaseChat):
report_name: str report_name: str
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = "", report_name:str="report"): def __init__(
self,
chat_session_id,
user_input,
select_param: str = "",
report_name: str = "report",
):
""" """ """ """
self.db_name=select_param self.db_name = select_param
super().__init__( super().__init__(
chat_mode=ChatScene.ChatDashboard, chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
@ -80,7 +86,9 @@ class ChatDashboard(BaseChat):
dashboard_data_loader = DashboardDataLoader() dashboard_data_loader = DashboardDataLoader()
for chart_item in prompt_response: for chart_item in prompt_response:
try: try:
field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.database, chart_item.sql) field_names, values = dashboard_data_loader.get_chart_values_by_conn(
self.database, chart_item.sql
)
chart_datas.append( chart_datas.append(
ChartData( ChartData(
chart_uid=str(uuid.uuid1()), chart_uid=str(uuid.uuid1()),
@ -101,5 +109,3 @@ class ChatDashboard(BaseChat):
template_introduce=None, template_introduce=None,
charts=chart_datas, charts=chart_datas,
) )

View File

@ -11,15 +11,14 @@ logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
class DashboardDataLoader: class DashboardDataLoader:
def get_sql_value(self, db_conn, chart_sql: str): def get_sql_value(self, db_conn, chart_sql: str):
return db_conn.query_ex(chart_sql) return db_conn.query_ex(chart_sql)
def get_chart_values_by_conn(self, db_conn, chart_sql: str) : def get_chart_values_by_conn(self, db_conn, chart_sql: str):
field_names, datas = db_conn.query_ex(chart_sql) field_names, datas = db_conn.query_ex(chart_sql)
return self.get_chart_values_by_data(field_names, datas, chart_sql) return self.get_chart_values_by_data(field_names, datas, chart_sql)
def get_chart_values_by_data(self, field_names, datas, chart_sql: str) : def get_chart_values_by_data(self, field_names, datas, chart_sql: str):
logger.info(f"get_chart_values_by_conn:{chart_sql}") logger.info(f"get_chart_values_by_conn:{chart_sql}")
try: try:
values: List[ValueItem] = [] values: List[ValueItem] = []
@ -57,7 +56,7 @@ class DashboardDataLoader:
logger.debug("Prepare Chart Data Faild!" + str(e)) logger.debug("Prepare Chart Data Faild!" + str(e))
raise ValueError("Prepare Chart Data Faild!") raise ValueError("Prepare Chart Data Faild!")
def get_chart_values_by_db(self, db_name: str, chart_sql: str) : def get_chart_values_by_db(self, db_name: str, chart_sql: str):
logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}") logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
return self.get_chart_values_by_conn(db_conn, chart_sql) return self.get_chart_values_by_conn(db_conn, chart_sql)

View File

@ -27,6 +27,7 @@ CFG = Config()
class ChatExcel(BaseChat): class ChatExcel(BaseChat):
chat_scene: str = ChatScene.ChatExcel.value() chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 1 chat_retention_rounds = 1
def __init__(self, chat_session_id, user_input, select_param: str = ""): def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatExcel chat_mode = ChatScene.ChatExcel
@ -34,9 +35,11 @@ class ChatExcel(BaseChat):
if has_path(select_param): if has_path(select_param):
self.excel_reader = ExcelReader(select_param) self.excel_reader = ExcelReader(select_param)
else: else:
self.excel_reader = ExcelReader(os.path.join( self.excel_reader = ExcelReader(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param os.path.join(
)) KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
)
)
super().__init__( super().__init__(
chat_mode=chat_mode, chat_mode=chat_mode,
@ -70,12 +73,9 @@ class ChatExcel(BaseChat):
] ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
def generate_input_values(self): def generate_input_values(self):
input_values = { input_values = {
"user_input": self.current_user_input, "user_input": self.current_user_input,
"table_name": self.excel_reader.table_name, "table_name": self.excel_reader.table_name,
"disply_type": self._generate_numbered_list(), "disply_type": self._generate_numbered_list(),
} }
@ -87,23 +87,21 @@ class ChatExcel(BaseChat):
return None return None
chat_param = { chat_param = {
"chat_session_id": self.chat_session_id, "chat_session_id": self.chat_session_id,
"user_input": "[" + self.excel_reader.excel_file_name +"]" + " Analysis", "user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analysis",
"parent_mode": self.chat_mode, "parent_mode": self.chat_mode,
"select_param":self.excel_reader.excel_file_name, "select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader "excel_reader": self.excel_reader,
} }
learn_chat = ExcelLearning(**chat_param) learn_chat = ExcelLearning(**chat_param)
result = learn_chat.nostream_call() result = learn_chat.nostream_call()
return result return result
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}") print(f"do_action:{prompt_response}")
# colunms, datas = self.excel_reader.run(prompt_response.sql) # colunms, datas = self.excel_reader.run(prompt_response.sql)
param= { param = {
"speak": prompt_response.thoughts, "speak": prompt_response.thoughts,
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql) "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
} }
return CFG.command_disply.call(prompt_response.display, **param) return CFG.command_disply.call(prompt_response.display, **param)

View File

@ -2,7 +2,9 @@ import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import ChatExcelOutputParser from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import (
ChatExcelOutputParser,
)
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@ -39,9 +41,9 @@ SQL中需要使用的表名是: {table_name}
""" """
RESPONSE_FORMAT_SIMPLE = { RESPONSE_FORMAT_SIMPLE = {
"sql": "analysis SQL", "sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis", "thoughts": "Current thinking and value of data analysis",
"display": "display type name" "display": "display type name",
} }
_DEFAULT_TEMPLATE = ( _DEFAULT_TEMPLATE = (
@ -67,9 +69,8 @@ prompt = PromptTemplate(
output_parser=ChatExcelOutputParser( output_parser=ChatExcelOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
), ),
need_historical_messages = True, need_historical_messages=True,
# example_selector=sql_data_example, # example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -23,7 +23,14 @@ CFG = Config()
class ExcelLearning(BaseChat): class ExcelLearning(BaseChat):
chat_scene: str = ChatScene.ExcelLearning.value() chat_scene: str = ChatScene.ExcelLearning.value()
def __init__(self, chat_session_id, user_input, parent_mode: Any=None, select_param:str=None, excel_reader:Any=None): def __init__(
self,
chat_session_id,
user_input,
parent_mode: Any = None,
select_param: str = None,
excel_reader: Any = None,
):
chat_mode = ChatScene.ExcelLearning chat_mode = ChatScene.ExcelLearning
""" """ """ """
self.excel_file_path = select_param self.excel_file_path = select_param
@ -31,20 +38,19 @@ class ExcelLearning(BaseChat):
super().__init__( super().__init__(
chat_mode=chat_mode, chat_mode=chat_mode,
chat_session_id=chat_session_id, chat_session_id=chat_session_id,
current_user_input = user_input, current_user_input=user_input,
select_param=select_param, select_param=select_param,
) )
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()
def generate_input_values(self): def generate_input_values(self):
colunms, datas = self.excel_reader.get_sample_data() colunms, datas = self.excel_reader.get_sample_data()
datas.insert(0, colunms) datas.insert(0, colunms)
input_values = { input_values = {
"data_example": json.dumps(self.excel_reader.get_sample_data(), cls=DateTimeEncoder), "data_example": json.dumps(
self.excel_reader.get_sample_data(), cls=DateTimeEncoder
),
} }
return input_values return input_values

View File

@ -37,23 +37,23 @@ class LearningExcelOutputParser(BaseOutputParser):
plans = response[key] plans = response[key]
return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans) return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
### tool out data to table view ### tool out data to table view
html_title = f"### **数据简介**\n{data.desciption} " html_title = f"### **数据简介**\n{data.desciption} "
html_colunms = f"### **数据结构**\n" html_colunms = f"### **数据结构**\n"
column_index = 0 column_index = 0
for item in data.clounms: for item in data.clounms:
column_index +=1 column_index += 1
keys = item.keys() keys = item.keys()
for key in keys: for key in keys:
html_colunms = html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" html_colunms = (
html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
)
html_plans = f"### **分析计划**\n" html_plans = f"### **分析计划**\n"
index = 0 index = 0
for item in data.plans: for item in data.plans:
index +=1 index += 1
html_plans = html_plans + f"{item} \n" html_plans = html_plans + f"{item} \n"
html = f"""{html_title}\n{html_colunms}\n{html_plans}""" html = f"""{html_title}\n{html_colunms}\n{html_plans}"""
return html return html

View File

@ -2,7 +2,9 @@ import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import LearningExcelOutputParser from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import (
LearningExcelOutputParser,
)
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@ -29,7 +31,7 @@ _DEFAULT_TEMPLATE_ZH = """
{response} {response}
""" """
RESPONSE_FORMAT_SIMPLE = { RESPONSE_FORMAT_SIMPLE = {
"DataAnalysis": "数据内容分析总结", "DataAnalysis": "数据内容分析总结",
"ColumnAnalysis": [{"column name1": "字段1介绍专业术语解释(请尽量简单明了)"}], "ColumnAnalysis": [{"column name1": "字段1介绍专业术语解释(请尽量简单明了)"}],
"AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"], "AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
@ -63,5 +65,3 @@ prompt = PromptTemplate(
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -7,7 +7,7 @@ import seaborn as sns
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import time import time
from fsspec import filesystem from fsspec import filesystem
import spatial import spatial
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
@ -18,22 +18,35 @@ if __name__ == "__main__":
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
# colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) colunms, datas = excel_reader.run(
df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;") """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """
)
df = excel_reader.get_df_by_sql_ex(
"SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;"
)
columns = df.columns.tolist() columns = df.columns.tolist()
plt.rcParams["font.family"] = ["sans-serif"] plt.rcParams["font.family"] = ["sans-serif"]
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False} rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) sns.set_style(rc={"font.sans-serif": "Microsoft Yahei"})
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc) sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
plt.subplots_adjust(top=0.9) plt.subplots_adjust(top=0.9)
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
# 手动设置 labels 的位置和大小 # 手动设置 labels 的位置和大小
ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10) ax.legend(
plt.axis('equal') # 使饼图为正圆形 loc="center left", bbox_to_anchor=(-1, 0.5, 0, 0), labels=None, fontsize=10
)
plt.axis("equal") # 使饼图为正圆形
plt.show() plt.show()
# #
# #

View File

@ -7,11 +7,13 @@ import numpy as np
from pilot.common.pd_utils import csv_colunm_foramt from pilot.common.pd_utils import csv_colunm_foramt
def excel_colunm_format(old_name:str)->str:
def excel_colunm_format(old_name: str) -> str:
new_column = old_name.strip() new_column = old_name.strip()
new_column = new_column.replace(" ", "_") new_column = new_column.replace(" ", "_")
return new_column return new_column
def add_quotes(sql, column_names=[]): def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "") sql = sql.replace("`", "")
parsed = sqlparse.parse(sql) parsed = sqlparse.parse(sql)
@ -20,44 +22,51 @@ def add_quotes(sql, column_names=[]):
deep_quotes(token, column_names) deep_quotes(token, column_names)
return str(parsed[0]) return str(parsed[0])
def deep_quotes(token, column_names=[]): def deep_quotes(token, column_names=[]):
if hasattr(token, "tokens") : if hasattr(token, "tokens"):
for token_child in token.tokens: for token_child in token.tokens:
deep_quotes(token_child, column_names) deep_quotes(token_child, column_names)
else: else:
if token.ttype == sqlparse.tokens.Name: if token.ttype == sqlparse.tokens.Name:
if len(column_names) >0: if len(column_names) > 0:
if token.value in column_names: if token.value in column_names:
token.value = f'"{token.value.replace("`", "")}"' token.value = f'"{token.value.replace("`", "")}"'
else: else:
token.value = f'"{token.value.replace("`", "")}"' token.value = f'"{token.value.replace("`", "")}"'
def is_chinese(string): def is_chinese(string):
# 使用正则表达式匹配中文字符 # 使用正则表达式匹配中文字符
pattern = re.compile(r'[一-龥]') pattern = re.compile(r"[一-龥]")
match = re.search(pattern, string) match = re.search(pattern, string)
return match is not None return match is not None
class ExcelReader: class ExcelReader:
def __init__(self, file_path): def __init__(self, file_path):
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
file_name_without_extension = os.path.splitext(file_name)[0] file_name_without_extension = os.path.splitext(file_name)[0]
self.excel_file_name = file_name self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1] self.extension = os.path.splitext(file_name)[1]
# read excel file # read excel file
if file_path.endswith('.xlsx') or file_path.endswith('.xls'): if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
df_tmp = pd.read_excel(file_path) df_tmp = pd.read_excel(file_path)
self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) self.df = pd.read_excel(
elif file_path.endswith('.csv'): file_path,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
elif file_path.endswith(".csv"):
df_tmp = pd.read_csv(file_path) df_tmp = pd.read_csv(file_path)
self.df = pd.read_csv(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) self.df = pd.read_csv(
file_path,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else: else:
raise ValueError("Unsupported file format.") raise ValueError("Unsupported file format.")
self.df.replace('', np.nan, inplace=True) self.df.replace("", np.nan, inplace=True)
self.columns_map = {} self.columns_map = {}
for column_name in df_tmp.columns: for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)}) self.columns_map.update({column_name: excel_colunm_format(column_name)})
@ -66,18 +75,17 @@ class ExcelReader:
except Exception as e: except Exception as e:
print("transfor column error" + column_name) print("transfor column error" + column_name)
self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_')) self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_"))
# connect DuckDB # connect DuckDB
self.db = duckdb.connect(database=':memory:', read_only=False) self.db = duckdb.connect(database=":memory:", read_only=False)
self.table_name = file_name_without_extension self.table_name = file_name_without_extension
# write data in duckdb # write data in duckdb
self.db.register(self.table_name, self.df) self.db.register(self.table_name, self.df)
def run(self, sql): def run(self, sql):
if f'"{self.table_name}"' not in sql: if f'"{self.table_name}"' not in sql:
sql = sql.replace(self.table_name, f'"{self.table_name}"') sql = sql.replace(self.table_name, f'"{self.table_name}"')
sql = add_quotes(sql, self.columns_map.values()) sql = add_quotes(sql, self.columns_map.values())
print(f"excute sql:{sql}") print(f"excute sql:{sql}")
@ -92,5 +100,4 @@ class ExcelReader:
return pd.DataFrame(values, columns=colunms) return pd.DataFrame(values, columns=colunms)
def get_sample_data(self): def get_sample_data(self):
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;') return self.run(f"SELECT * FROM {self.table_name} LIMIT 5;")

View File

@ -21,7 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = ""): def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatWithDbExecute chat_mode = ChatScene.ChatWithDbExecute
self.db_name = select_param self.db_name = select_param
""" """ """ """
@ -47,7 +47,9 @@ class ChatWithDbAutoExecute(BaseChat):
client = DBSummaryClient() client = DBSummaryClient()
try: try:
table_infos = client.get_db_summary( table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE dbname=self.db_name,
query=self.current_user_input,
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
) )
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
@ -65,4 +67,4 @@ class ChatWithDbAutoExecute(BaseChat):
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}") print(f"do_action:{prompt_response}")
return self.database.run( prompt_response.sql) return self.database.run(prompt_response.sql)

View File

@ -8,6 +8,7 @@ from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.data_loader import DbDataLoader from pilot.scene.chat_db.data_loader import DbDataLoader
CFG = Config() CFG = Config()
@ -49,6 +50,4 @@ class DbChatOutputParser(BaseOutputParser):
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text return view_text
else: else:
return data_loader.get_table_view_by_conn(data, speak) return data_loader.get_table_view_by_conn(data, speak)

View File

@ -1,8 +1,7 @@
import pandas as pd import pandas as pd
class DbDataLoader: class DbDataLoader:
def get_table_view_by_conn(self, data, speak): def get_table_view_by_conn(self, data, speak):
### tool out data to table view ### tool out data to table view
if len(data) <= 1: if len(data) <= 1:
@ -12,4 +11,4 @@ class DbDataLoader:
table_str = "".join(html_table.split()) table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>""" html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text return view_text

View File

@ -19,7 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = ""): def __init__(self, chat_session_id, user_input, select_param: str = ""):
""" """ """ """
self.db_name = select_param self.db_name = select_param
super().__init__( super().__init__(

View File

@ -20,12 +20,7 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator: PluginPromptGenerator plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None select_plugin: str = None
def __init__( def __init__(self, chat_session_id, user_input, select_param: str = None):
self,
chat_session_id,
user_input,
select_param: str = None
):
self.plugin_selector = select_param self.plugin_selector = select_param
super().__init__( super().__init__(
chat_mode=ChatScene.ChatExecution, chat_mode=ChatScene.ChatExecution,

View File

@ -30,7 +30,7 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = None): def __init__(self, chat_session_id, user_input, select_param: str = None):
""" """ """ """
self.knowledge_space = select_param self.knowledge_space = select_param
super().__init__( super().__init__(

View File

@ -30,7 +30,6 @@ class ChatNormal(BaseChat):
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatNormal.value return ChatScene.ChatNormal.value

View File

@ -117,12 +117,10 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"tokens": once.tokens if once.tokens else 0, "tokens": once.tokens if once.tokens else 0,
"messages": messages_to_dict(once.messages), "messages": messages_to_dict(once.messages),
"param_type": once.param_type, "param_type": once.param_type,
"param_value": once.param_value "param_value": once.param_value,
} }
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
return [_conversation_to_dic(m) for m in conversations] return [_conversation_to_dic(m) for m in conversations]

View File

@ -64,16 +64,12 @@ def server_init(args):
cfg.command_registry = command_registry cfg.command_registry = command_registry
command_disply_commands = [ command_disply_commands = [
"pilot.commands.disply_type.show_chart_gen", "pilot.commands.disply_type.show_chart_gen",
"pilot.commands.disply_type.show_table_gen", "pilot.commands.disply_type.show_table_gen",
"pilot.commands.disply_type.show_text_gen", "pilot.commands.disply_type.show_text_gen",
] ]
command_disply_registry = CommandRegistry() command_disply_registry = CommandRegistry()
for command in command_disply_commands: for command in command_disply_commands:
command_disply_registry.import_commands(command) command_disply_registry.import_commands(command)
cfg.command_disply = command_disply_registry cfg.command_disply = command_disply_registry

View File

@ -81,7 +81,9 @@ app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1) # app.include_router(api_editor_route_v1)
os.makedirs(static_message_img_path, exist_ok=True) os.makedirs(static_message_img_path, exist_ok=True)
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2") app.mount(
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
)
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static")) app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")

View File

@ -19,7 +19,6 @@ from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log") logger = build_logger("db_summary", LOGDIR + "db_summary.log")
CFG = Config() CFG = Config()
chat_factory = ChatFactory() chat_factory = ChatFactory()
@ -145,7 +144,9 @@ class DBSummaryClient:
try: try:
self.db_summary_embedding(item["db_name"], item["db_type"]) self.db_summary_embedding(item["db_name"], item["db_type"])
except Exception as e: except Exception as e:
logger.warn(f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e) logger.warn(
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e
)
def init_db_profile(self, db_summary_client, dbname, embeddings): def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = { profile_store_config = {

View File

@ -75,7 +75,6 @@ def build_logger(logger_name, logger_filename):
item.addHandler(handler) item.addHandler(handler)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# Get logger # Get logger
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)