feat(editor): ChatExcel

ChatExcel development, part 1
This commit is contained in:
yhjun1026
2023-08-15 16:24:09 +08:00
parent 0bd99d7764
commit 44225e9aea
13 changed files with 472 additions and 45 deletions

View File

@@ -0,0 +1,56 @@
from pilot.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
import base64
import io
import matplotlib
import seaborn as sns
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
@command(
    "response_line_chart",
    "Use line chart to display SQL data",
    '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"',
)
def response_line_chart(speak: str, sql: str, db_name: str) -> str:
    """Run ``sql`` against ``db_name`` and render the result as a line chart.

    The first result column is plotted on the x axis and the second on the
    y axis.

    Args:
        speak: Caption shown above the chart.
        sql: Query whose result set is plotted.
        db_name: Name used to look up the connection via
            ``CFG.LOCAL_DB_MANAGE.get_connect``.

    Returns:
        An HTML fragment: an ``<h5>`` caption followed by an ``<img>`` whose
        source is the chart as a base64-encoded PNG data URI.

    Raises:
        ValueError: If the query returns no data or fewer than two columns.
    """
    logger.info(f"response_line_chart:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    if df.size <= 0:
        raise ValueError("No Data" + sql)
    columns = df.columns.tolist()
    if len(columns) < 2:
        # Both an x and a y column are required below.
        raise ValueError("Line chart needs at least two columns: " + sql)

    plt.rcParams["font.family"] = ["sans-serif"]
    rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
    sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    try:
        # The command is registered as a line chart, so draw a line plot.
        sns.lineplot(data=df, x=columns[0], y=columns[1], ax=ax)
        ax.set_title("")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=100)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)

    data = base64.b64encode(buf.getvalue()).decode("ascii")
    html_img = f"""<h5>{speak}</h5><img style='max-width: 120%; max-height: 80%;' src="data:image/png;base64,{data}" />"""
    return html_img
@command(
    "response_bar_chart",
    "Use bar chart to display SQL data",
    '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"',
)
def response_bar_chart(speak: str, sql: str, db_name: str) -> str:
    """Placeholder for the bar-chart command; nothing is rendered yet.

    The command is registered, but the body is empty: every call returns
    ``None`` regardless of its arguments.
    """
    return None
@command(
    "response_pie_chart",
    "Use pie chart to display SQL data",
    '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"',
)
def response_pie_chart(speak: str, sql: str, db_name: str) -> str:
    """Placeholder for the pie-chart command; nothing is rendered yet.

    The command is registered, but the body is empty: every call returns
    ``None`` regardless of its arguments.
    """
    return None

View File

@@ -0,0 +1,21 @@
import pandas as pd
from pilot.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
    "response_table",
    "Use table to display SQL data",
    '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"',
)
def response_table(speak: str, sql: str, db_name: str) -> str:
    """Run ``sql`` against ``db_name`` and render the result as an HTML table.

    Args:
        speak: Caption placed in a markdown heading above the table.
        sql: Query whose result set is displayed.
        db_name: Name used to look up the connection via
            ``CFG.LOCAL_DB_MANAGE.get_connect``.

    Returns:
        A markdown ``#####`` heading line followed by a single line of HTML
        containing the table wrapped in a scrollable ``<div>``.
    """
    logger.info(f"response_table:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    # NOTE(review): escape=False emits cell values unescaped; confirm the
    # data shown here is trusted.
    html_table = df.to_html(index=False, escape=False, sparsify=False)
    # Collapse every whitespace run (newlines included) to ONE space so the
    # markup stays on a single line. Joining with "" would fuse tag
    # attributes (`<table border=...>` -> `<tableborder=...>`) and delete the
    # spaces inside cell values.
    table_str = " ".join(html_table.split())
    html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
    return f"##### {str(speak)}" + "\n" + html

View File

@@ -0,0 +1,33 @@
import pandas as pd
from pilot.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
    "response_data_text",
    "Use text to display SQL data",
    '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"',
)
def response_data_text(speak: str, sql: str, db_name: str) -> str:
    """Run ``sql`` against ``db_name`` and describe the result as text.

    Args:
        speak: Caption or lead-in text for the result.
        sql: Query whose result set is displayed.
        db_name: Name used to look up the connection via
            ``CFG.LOCAL_DB_MANAGE.get_connect``.

    Returns:
        * more than one row: a markdown heading followed by a one-line HTML
          table;
        * exactly one row: ``speak`` followed by the row's values, comma
          separated;
        * no rows: a heading with a "no data found" notice.
    """
    logger.info(f"response_data_text:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    data = df.values
    row_size = data.shape[0]

    if row_size > 1:
        html_table = df.to_html(index=False, escape=False, sparsify=False)
        # Collapse whitespace runs to one space: keeps the markup on a single
        # line without fusing tag attributes or words inside cell values.
        table_str = " ".join(html_table.split())
        html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
        return f"##### {str(speak)}" + "\n" + html

    if row_size == 1:
        # join() avoids the leading ", " that prefix-concatenation produced.
        value_str = ", ".join(f"** {value} **" for value in data[0])
        return f"{speak}: {value_str}"

    return f"##### {speak}: _没有找到可用的数据_"

View File

@@ -39,7 +39,7 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
@router.get("/v1/editor/db/tables", response_model=Result[DbTable]) @router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""): async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str) logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
tables = db_conn.get_table_names() tables = db_conn.get_table_names()
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db") db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
@@ -57,7 +57,7 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds]) @router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str): async def get_editor_sql_rounds(con_uid: str):
logger.info("get_editor_sql_rounds:{}", con_uid) logger.info("get_editor_sql_rounds:{con_uid}")
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
@@ -76,7 +76,7 @@ async def get_editor_sql_rounds(con_uid: str):
@router.get("/v1/editor/sql", response_model=Result[dict]) @router.get("/v1/editor/sql", response_model=Result[dict])
async def get_editor_sql(con_uid: str, round: int): async def get_editor_sql(con_uid: str, round: int):
logger.info("get_editor_sql:{},{}", con_uid, round) logger.info(f"get_editor_sql:{con_uid},{round}")
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
@@ -90,26 +90,33 @@ async def get_editor_sql(con_uid: str, round: int):
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData]) @router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()): async def editor_sql_run(run_param: dict = Body()):
logger.info("editor_sql_run:{}", run_param) logger.info(f"editor_sql_run:{run_param}")
db_name = run_param['db_name'] db_name = run_param['db_name']
sql = run_param['sql'] sql = run_param['sql']
if not db_name and not sql: if not db_name and not sql:
return Result.faild("SQL run param error") return Result.faild("SQL run param error")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
start_time = time.time() * 1000 try:
colunms, sql_result = conn.query_ex(sql) start_time = time.time() * 1000
# 计算执行耗时 colunms, sql_result = conn.query_ex(sql)
end_time = time.time() * 1000 # 计算执行耗时
sql_run_data: SqlRunData = SqlRunData(result_info="", end_time = time.time() * 1000
run_cost=(end_time - start_time) / 1000, sql_run_data: SqlRunData = SqlRunData(result_info="",
colunms=colunms, run_cost=(end_time - start_time) / 1000,
values=sql_result colunms=colunms,
) values=sql_result
return Result.succ(sql_run_data) )
return Result.succ(sql_run_data)
except Exception as e:
return Result.succ(SqlRunData(result_info=str(e),
run_cost=0,
colunms=[],
values=[]
))
@router.post("/v1/sql/editor/submit", response_model=Result) @router.post("/v1/sql/editor/submit")
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}") logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid) history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
@@ -136,7 +143,7 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
@router.get("/v1/editor/chart/list", response_model=Result[ChartList]) @router.get("/v1/editor/chart/list", response_model=Result[ChartList])
async def get_editor_chart_list(con_uid: str): async def get_editor_chart_list(con_uid: str):
logger.info("get_editor_sql_rounds:{}", con_uid) logger.info(f"get_editor_sql_rounds:{con_uid}", )
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
@@ -152,8 +159,7 @@ async def get_editor_chart_list(con_uid: str):
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail]) @router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
async def get_editor_chart_info(con_uid: str, chart_title: str): async def get_editor_chart_info(con_uid: str, chart_title: str):
logger.info(f"get_editor_sql_rounds:{con_uid}") logger.info(f"get_editor_chart_info:{con_uid},{chart_title}")
logger.info("get_editor_sql_rounds:{}", con_uid)
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
@@ -184,24 +190,34 @@ async def get_editor_chart_info(con_uid: str, chart_title: str):
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(db_name: str, sql: str): async def editor_chart_run(run_param: dict = Body()):
logger.info(f"editor_chart_run:{db_name},{sql}") logger.info(f"editor_chart_run:{run_param}")
db_name = run_param['db_name']
sql = run_param['sql']
if not db_name and not sql:
return Result.faild("SQL run param error")
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, sql) field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, sql)
start_time = time.time() * 1000 try:
colunms, sql_result = db_conn.query_ex(sql) start_time = time.time() * 1000
# 计算执行耗时 colunms, sql_result = db_conn.query_ex(sql)
end_time = time.time() * 1000 # 计算执行耗时
sql_run_data: SqlRunData = SqlRunData(result_info="", end_time = time.time() * 1000
run_cost=(end_time - start_time) / 1000, sql_run_data: SqlRunData = SqlRunData(result_info="",
colunms=colunms, run_cost=(end_time - start_time) / 1000,
values=sql_result colunms=colunms,
) values=sql_result
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values)) )
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values))
except Exception as e:
return Result.succ(SqlRunData(result_info=str(e),
run_cost=0,
colunms=[],
values=[]
))
@router.post("/v1/chart/editor/submit", response_model=Result[bool]) @router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()): async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):

View File

@@ -57,8 +57,8 @@ class ChatSqlEditContext(BaseModel):
old_sql: str old_sql: str
old_speak: str old_speak: str
gmt_create: int gmt_create: int = 0
new_sql: str new_sql: str
new_speak: str new_speak: str = ""
new_view_info: str new_view_info: str = ""

View File

@@ -226,13 +226,13 @@ class BaseOutputParser(ABC):
"""Instructions on how the LLM output should be formatted.""" """Instructions on how the LLM output should be formatted."""
raise NotImplementedError raise NotImplementedError
@property # @property
def _type(self) -> str: # def _type(self) -> str:
"""Return the type key.""" # """Return the type key."""
raise NotImplementedError( # raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}." # f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization." # " This is required for serialization."
) # )
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser.""" """Return dictionary representation of output parser."""

View File

@@ -27,6 +27,21 @@ class ChatScene(Enum):
"Dialogue with your private data through natural language.", "Dialogue with your private data through natural language.",
["DB Select"], ["DB Select"],
) )
ExcelLearning = Scene(
"excel_learning",
"Excel Learning",
"Analyze and summarize your excel files.",
[],
True,
True
)
ChatExcel = Scene(
"chat_excel",
"Chat Excel",
"Dialogue with your excel, use natural language.",
["File Select"],
)
ChatWithDbQA = Scene( ChatWithDbQA = Scene(
"chat_with_db_qa", "chat_with_db_qa",
"Chat DB", "Chat DB",

View File

@@ -0,0 +1,51 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()

PROMPT_SCENE_DEFINE = "You are a data analysis expert. "

# Prompt body. {data_example} is filled from the prompt's input variables;
# {response} is the placeholder for the JSON reply layout described by
# RESPONSE_FORMAT_SIMPLE.
_DEFAULT_TEMPLATE = """
This is an example data, please learn to understand the structure and content of this data:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
Provide some analysis options,please think step by step.
Please return your answer in JSON format, the return format is as follows:
{response}
"""

# JSON layout the model is asked to reply with. The values are Chinese
# descriptions of what belongs under each key. The key is spelled
# "Column Analysis" to match the key ChatExcelOutputParser looks up.
RESPONSE_FORMAT_SIMPLE = {
    "Data Analysis": "数据内容分析总结",
    "Column Analysis": [{"column name": "字段介绍,专业术语解释(请尽量简单明了)"}],
    "Analysis Program": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

# NOTE(review): DbChatOutputParser is imported from the chat_db scene; confirm
# it can parse the JSON layout requested above.
prompt = PromptTemplate(
    template_scene=ChatScene.ChatExcel.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -0,0 +1,39 @@
import json
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
CFG = Config()
class ExcelLearning(BaseChat):
    """Chat scene registered as ``ChatScene.ExcelLearning``."""

    chat_scene: str = ChatScene.ExcelLearning.value()

    def __init__(self, chat_session_id, file_path):
        """Create the chat for one session and one Excel file.

        Args:
            chat_session_id: Identifier of the chat session; forwarded to
                ``BaseChat``.
            file_path: Path of the Excel file; forwarded to ``BaseChat`` as
                ``select_param``.
        """
        # Use this scene's own mode so it agrees with `chat_scene` above and
        # with the prompt imported from excel_learning.prompt (the mode was
        # previously ChatScene.ChatWithDbExecute).
        # NOTE(review): confirm BaseChat selects the prompt by chat_mode.
        super().__init__(
            chat_mode=ChatScene.ExcelLearning,
            chat_session_id=chat_session_id,
            select_param=file_path,
        )

    def generate_input_values(self):
        """Return the prompt's input variables.

        ``data_example`` is currently always an empty string.
        """
        return {"data_example": ""}

View File

@@ -0,0 +1,65 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
CFG = Config()
class ExcelResponse(NamedTuple):
    """Structured form of the model's reply about an Excel data sample."""

    # Value found under the "Data Analysis" key of the reply.
    desciption: str
    # Value found under the "Column Analysis" key of the reply.
    clounms: List
    # Value found under the "Analysis Program" key of the reply.
    plans: List
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
class ChatExcelOutputParser(BaseOutputParser):
    """Parse the model's JSON reply about an Excel sample and render it as HTML."""

    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, model_out_text):
        """Turn the model's raw text into an ``ExcelResponse``.

        The text is first cleaned by the base class and then decoded as JSON.
        Entries that are absent from the reply fall back to an empty string
        or an empty list instead of leaving a local variable unbound.

        Raises:
            json.JSONDecodeError: If the cleaned text is not valid JSON.
        """
        clean_str = super().parse_prompt_response(model_out_text)
        logger.info(f"clean prompt response: {clean_str}")
        response = json.loads(clean_str)

        desciption = ""
        clounms = []
        plans = []
        for key in sorted(response):
            name = key.strip()
            if name == "Data Analysis":
                desciption = response[key]
            elif name in ("Column Analysis", "Colunm Analysis"):
                # The prompt templates spell this key "Colunm Analysis";
                # accept both spellings so the section is not dropped.
                clounms = response[key]
            elif name == "Analysis Program":
                plans = response[key]
        return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)

    @staticmethod
    def _field(data, name):
        """Read ``name`` from a mapping or from an attribute-style record."""
        if isinstance(data, dict):
            return data[name]
        return getattr(data, name)

    def parse_view_response(self, speak, data) -> str:
        """Render the parsed response as an HTML fragment.

        Args:
            speak: Unused here.
            data: The parsed response, either an ``ExcelResponse`` or a
                mapping with the same field names.

        Returns:
            HTML with the description as a heading, the column explanations
            as an unordered list and the analysis plans as an ordered list.
        """
        html_title = self._field(data, "desciption")

        column_items = []
        for item in self._field(data, "clounms"):
            pairs = "".join(f"{key}:{value}" for key, value in item.items())
            column_items.append(f"<li>{pairs}</li>")
        html_colunms = "<h5>数据结构</h5><ul>" + "".join(column_items) + "</ul>"

        plan_items = "".join(
            f"<li>{item}</li>" for item in self._field(data, "plans")
        )
        # Close the ordered list, which was previously left open.
        html_plans = "<ol>" + plan_items + "</ol>"

        return (
            "\n<div>\n"
            f"<h4>{html_title}</h4>\n"
            f"<div>{html_colunms}</div>\n"
            f"<div>{html_plans}</div>\n"
            "</div>\n"
        )

View File

@@ -0,0 +1,51 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
CFG = Config()

PROMPT_SCENE_DEFINE = "You are a data analysis expert. "

# Prompt body. {data_example} is filled from the prompt's input variables;
# {response} is the placeholder for the JSON reply layout described by
# RESPONSE_FORMAT_SIMPLE.
_DEFAULT_TEMPLATE = """
This is an example data, please learn to understand the structure and content of this data:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
Provide some analysis options,please think step by step.
Please return your answer in JSON format, the return format is as follows:
{response}
"""

# JSON layout the model is asked to reply with. The values are Chinese
# descriptions of what belongs under each key. The key is spelled
# "Column Analysis" to match the key ChatExcelOutputParser looks up.
RESPONSE_FORMAT_SIMPLE = {
    "Data Analysis": "数据内容分析总结",
    "Column Analysis": [{"column name": "字段介绍,专业术语解释(请尽量简单明了)"}],
    "Analysis Program": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

# NOTE(review): DbChatOutputParser is imported from the chat_db scene; confirm
# it can parse the JSON layout requested above.
prompt = PromptTemplate(
    template_scene=ChatScene.ExcelLearning.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -0,0 +1,84 @@
import os
import duckdb
import pandas as pd
import time
from fsspec import filesystem
import spatial
def csv_colunm_foramt(val):
    """Convert a currency-formatted cell such as ``"$1,234.50"`` to a float.

    Strings containing ``$`` or ``¥`` have that symbol and any thousands
    separators removed and are returned as ``float``. Every other value,
    including non-string cells, is returned unchanged.
    """
    if isinstance(val, str):
        if "$" in val:
            return float(val.replace("$", "").replace(",", ""))
        if "¥" in val:
            return float(val.replace("¥", "").replace(",", ""))
    return val


def main(excel_path="/Users/tuyang.yhj/Downloads/example.xlsx"):
    """Scratch experiment: load an Excel sheet with pandas and query it via DuckDB.

    Prints the row count, the first cell and width of every row, the load
    time, and the first five rows as seen through an in-memory DuckDB
    connection.

    Earlier experiments that are no longer run: loading DuckDB's ``spatial``
    extension and reading the workbook with ``st_read(..., layer='Sheet1')``,
    and querying CSV files directly with ``duckdb.sql``.
    """
    # Timestamp (ms) taken before loading starts.
    start_time = int(time.time() * 1000)

    # Only the column count is needed to attach the converter to every column
    # by position, so read just the header row rather than the whole sheet.
    header = pd.read_excel(excel_path, nrows=0)

    # Read the Excel file into a pandas DataFrame.
    df = pd.read_excel(
        excel_path,
        converters={i: csv_colunm_foramt for i in range(header.shape[1])},
    )

    d = df.values
    print(d.shape[0])
    for row in d:
        print(row[0])
        print(len(row))

    # Timestamp (ms) taken after loading finishes.
    end_time = int(time.time() * 1000)
    print(f"耗时:{(end_time-start_time)/1000}")

    # Connect to an in-memory DuckDB database.
    con = duckdb.connect(database=":memory:", read_only=False)
    try:
        # Expose the DataFrame to DuckDB as a table named "example".
        con.register("example", df)

        # Query the registered table.
        results = con.execute("SELECT * FROM example limit 5 ")
        colunms = [descrip[0] for descrip in results.description]
        print(colunms)
        for row in results.fetchall():
            print(row)
    finally:
        con.close()

    print("xx")


if __name__ == "__main__":
    main()

View File

@@ -52,7 +52,3 @@ class DbChatOutputParser(BaseOutputParser):
return data_loader.get_table_view_by_conn(data, speak) return data_loader.get_table_view_by_conn(data, speak)
@property
def _type(self) -> str:
return "sql_chat"