mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 15:10:14 +00:00
feat(editor): ChatExcel
ChatExcel devlop part 1
This commit is contained in:
56
pilot/commands/built_in/show_chart_gen.py
Normal file
56
pilot/commands/built_in/show_chart_gen.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from pilot.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
import pandas as pd
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import matplotlib
|
||||||
|
import seaborn as sns
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
# Shared application configuration (DB connection manager, prompt registry, ...).
CFG = Config()
# Module-level logger; build_logger also attaches a file handler under LOGDIR.
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_line_chart", "Use line chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_line_chart(speak: str, sql: str, db_name: str) -> str:
    """Run *sql* against *db_name* and return an HTML fragment embedding a line chart.

    The first result column is used as the x axis and the second as the y axis.
    The chart is rendered off-screen (Agg backend) and inlined as a base64 PNG,
    prefixed with the model's narration *speak*.

    Raises:
        ValueError: if the query returns no rows.
    """
    logger.info(f"response_line_chart:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    columns = df.columns.tolist()

    if df.size <= 0:
        raise ValueError("No Data!" + sql)

    # SimHei so CJK labels render; unicode_minus off so "-" survives with that font.
    plt.rcParams["font.family"] = ["sans-serif"]
    rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
    sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    # Defect fix: this command advertises a line chart but previously drew a
    # bar chart (sns.barplot) — copy-paste from the bar-chart command.
    sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)

    buf = io.BytesIO()
    plt.savefig(buf, format="png", dpi=100)
    buf.seek(0)
    data = base64.b64encode(buf.getvalue()).decode("ascii")
    # Defect fix: release the figure — Agg figures otherwise accumulate in
    # pyplot's registry and leak memory across command invocations.
    plt.close(fig)

    html_img = f"""<h5>{speak}</h5><img style='max-width: 120%; max-height: 80%;' src="data:image/png;base64,{data}" />"""
    return html_img
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_bar_chart", "Use bar chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_bar_chart(speak: str, sql: str, db_name: str) -> str:
    """Render SQL results as a bar chart.

    Not implemented yet: placeholder registered so the command is discoverable.
    Currently returns None despite the declared ``-> str``.
    """
    pass
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_pie_chart", "Use pie chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_pie_chart(speak: str, sql: str, db_name: str) -> str:
    """Render SQL results as a pie chart.

    Not implemented yet: placeholder registered so the command is discoverable.
    Currently returns None despite the declared ``-> str``.
    """
    pass
|
21
pilot/commands/built_in/show_table_gen.py
Normal file
21
pilot/commands/built_in/show_table_gen.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from pilot.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
# Shared application configuration (DB connection manager, prompt registry, ...).
CFG = Config()
# Module-level logger; build_logger also attaches a file handler under LOGDIR.
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_table", "Use table to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_table(speak: str, sql: str, db_name: str) -> str:
    """Run *sql* against *db_name* and return its result as an HTML table,
    prefixed with the model's narration *speak* as a markdown heading."""
    logger.info(f"response_table:{speak},{sql},{db_name}")
    conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
    frame = pd.read_sql(sql, conn)
    # Collapse every whitespace run in the rendered table so the fragment
    # survives downstream markdown/newline post-processing intact.
    rendered = frame.to_html(index=False, escape=False, sparsify=False)
    compact = "".join(rendered.split())
    wrapped = f"""<div class="w-full overflow-auto">{compact}</div>"""
    return f"##### {str(speak)}" + "\n" + wrapped.replace("\n", " ")
|
33
pilot/commands/built_in/show_text_gen.py
Normal file
33
pilot/commands/built_in/show_text_gen.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from pilot.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
# Shared application configuration (DB connection manager, prompt registry, ...).
CFG = Config()
# Defect fix: logger name and log file were copy-pasted from show_table_gen,
# so this module's logs landed in show_table_gen.log under the wrong channel.
logger = build_logger("show_text_gen", LOGDIR + "show_text_gen.log")
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_data_text", "Use text to display SQL data",
         '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_data_text(speak: str, sql: str, db_name: str) -> str:
    """Run *sql* against *db_name* and render the result as text.

    Multi-row results fall back to an HTML table; a single row is rendered as
    bold inline values; an empty result yields a "no data" message.
    """
    logger.info(f"response_data_text:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    data = df.values

    row_size = data.shape[0]
    # Defect fix: the original `value_str, text_info = ""` tried to unpack an
    # empty string into two names and raised ValueError on every call.
    value_str = ""
    text_info = ""
    if row_size > 1:
        # Collapse whitespace so the HTML fragment survives markdown post-processing.
        html_table = df.to_html(index=False, escape=False, sparsify=False)
        table_str = "".join(html_table.split())
        html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
        text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
    elif row_size == 1:
        row = data[0]
        for value in row:
            value_str = value_str + f", ** {value} **"
        text_info = f"{speak}: {value_str}"
    else:
        text_info = f"##### {speak}: _没有找到可用的数据_"
    return text_info
|
@@ -39,7 +39,7 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
|
|||||||
|
|
||||||
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
||||||
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
|
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
|
||||||
logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str)
|
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
|
||||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
tables = db_conn.get_table_names()
|
tables = db_conn.get_table_names()
|
||||||
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
|
db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
|
||||||
@@ -57,7 +57,7 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc
|
|||||||
|
|
||||||
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
|
@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
|
||||||
async def get_editor_sql_rounds(con_uid: str):
|
async def get_editor_sql_rounds(con_uid: str):
|
||||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
logger.info("get_editor_sql_rounds:{con_uid}")
|
||||||
history_mem = DuckdbHistoryMemory(con_uid)
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
if history_messages:
|
if history_messages:
|
||||||
@@ -76,7 +76,7 @@ async def get_editor_sql_rounds(con_uid: str):
|
|||||||
|
|
||||||
@router.get("/v1/editor/sql", response_model=Result[dict])
|
@router.get("/v1/editor/sql", response_model=Result[dict])
|
||||||
async def get_editor_sql(con_uid: str, round: int):
|
async def get_editor_sql(con_uid: str, round: int):
|
||||||
logger.info("get_editor_sql:{},{}", con_uid, round)
|
logger.info(f"get_editor_sql:{con_uid},{round}")
|
||||||
history_mem = DuckdbHistoryMemory(con_uid)
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
if history_messages:
|
if history_messages:
|
||||||
@@ -90,13 +90,14 @@ async def get_editor_sql(con_uid: str, round: int):
|
|||||||
|
|
||||||
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
|
||||||
async def editor_sql_run(run_param: dict = Body()):
|
async def editor_sql_run(run_param: dict = Body()):
|
||||||
logger.info("editor_sql_run:{}", run_param)
|
logger.info(f"editor_sql_run:{run_param}")
|
||||||
db_name = run_param['db_name']
|
db_name = run_param['db_name']
|
||||||
sql = run_param['sql']
|
sql = run_param['sql']
|
||||||
if not db_name and not sql:
|
if not db_name and not sql:
|
||||||
return Result.faild("SQL run param error!")
|
return Result.faild("SQL run param error!")
|
||||||
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
|
||||||
|
try:
|
||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
colunms, sql_result = conn.query_ex(sql)
|
colunms, sql_result = conn.query_ex(sql)
|
||||||
# 计算执行耗时
|
# 计算执行耗时
|
||||||
@@ -107,9 +108,15 @@ async def editor_sql_run(run_param: dict = Body()):
|
|||||||
values=sql_result
|
values=sql_result
|
||||||
)
|
)
|
||||||
return Result.succ(sql_run_data)
|
return Result.succ(sql_run_data)
|
||||||
|
except Exception as e:
|
||||||
|
return Result.succ(SqlRunData(result_info=str(e),
|
||||||
|
run_cost=0,
|
||||||
|
colunms=[],
|
||||||
|
values=[]
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/sql/editor/submit", response_model=Result)
|
@router.post("/v1/sql/editor/submit")
|
||||||
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
||||||
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
|
logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
|
||||||
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
|
history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
|
||||||
@@ -136,7 +143,7 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
|||||||
|
|
||||||
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
|
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
|
||||||
async def get_editor_chart_list(con_uid: str):
|
async def get_editor_chart_list(con_uid: str):
|
||||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
logger.info(f"get_editor_sql_rounds:{con_uid}", )
|
||||||
history_mem = DuckdbHistoryMemory(con_uid)
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
if history_messages:
|
if history_messages:
|
||||||
@@ -152,8 +159,7 @@ async def get_editor_chart_list(con_uid: str):
|
|||||||
|
|
||||||
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
|
||||||
async def get_editor_chart_info(con_uid: str, chart_title: str):
|
async def get_editor_chart_info(con_uid: str, chart_title: str):
|
||||||
logger.info(f"get_editor_sql_rounds:{con_uid}")
|
logger.info(f"get_editor_chart_info:{con_uid},{chart_title}")
|
||||||
logger.info("get_editor_sql_rounds:{}", con_uid)
|
|
||||||
history_mem = DuckdbHistoryMemory(con_uid)
|
history_mem = DuckdbHistoryMemory(con_uid)
|
||||||
history_messages: List[OnceConversation] = history_mem.get_messages()
|
history_messages: List[OnceConversation] = history_mem.get_messages()
|
||||||
if history_messages:
|
if history_messages:
|
||||||
@@ -184,13 +190,18 @@ async def get_editor_chart_info(con_uid: str, chart_title: str):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
|
||||||
async def editor_chart_run(db_name: str, sql: str):
|
async def editor_chart_run(run_param: dict = Body()):
|
||||||
logger.info(f"editor_chart_run:{db_name},{sql}")
|
logger.info(f"editor_chart_run:{run_param}")
|
||||||
|
db_name = run_param['db_name']
|
||||||
|
sql = run_param['sql']
|
||||||
|
if not db_name and not sql:
|
||||||
|
return Result.faild("SQL run param error!")
|
||||||
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
|
||||||
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
|
|
||||||
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, sql)
|
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, sql)
|
||||||
|
|
||||||
|
try:
|
||||||
start_time = time.time() * 1000
|
start_time = time.time() * 1000
|
||||||
colunms, sql_result = db_conn.query_ex(sql)
|
colunms, sql_result = db_conn.query_ex(sql)
|
||||||
# 计算执行耗时
|
# 计算执行耗时
|
||||||
@@ -201,7 +212,12 @@ async def editor_chart_run(db_name: str, sql: str):
|
|||||||
values=sql_result
|
values=sql_result
|
||||||
)
|
)
|
||||||
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values))
|
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values))
|
||||||
|
except Exception as e:
|
||||||
|
return Result.succ(SqlRunData(result_info=str(e),
|
||||||
|
run_cost=0,
|
||||||
|
colunms=[],
|
||||||
|
values=[]
|
||||||
|
))
|
||||||
|
|
||||||
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
|
||||||
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
|
||||||
|
@@ -57,8 +57,8 @@ class ChatSqlEditContext(BaseModel):
|
|||||||
|
|
||||||
old_sql: str
|
old_sql: str
|
||||||
old_speak: str
|
old_speak: str
|
||||||
gmt_create: int
|
gmt_create: int = 0
|
||||||
|
|
||||||
new_sql: str
|
new_sql: str
|
||||||
new_speak: str
|
new_speak: str = ""
|
||||||
new_view_info: str
|
new_view_info: str = ""
|
||||||
|
@@ -226,13 +226,13 @@ class BaseOutputParser(ABC):
|
|||||||
"""Instructions on how the LLM output should be formatted."""
|
"""Instructions on how the LLM output should be formatted."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
# @property
|
||||||
def _type(self) -> str:
|
# def _type(self) -> str:
|
||||||
"""Return the type key."""
|
# """Return the type key."""
|
||||||
raise NotImplementedError(
|
# raise NotImplementedError(
|
||||||
f"_type property is not implemented in class {self.__class__.__name__}."
|
# f"_type property is not implemented in class {self.__class__.__name__}."
|
||||||
" This is required for serialization."
|
# " This is required for serialization."
|
||||||
)
|
# )
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> Dict:
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
"""Return dictionary representation of output parser."""
|
"""Return dictionary representation of output parser."""
|
||||||
|
@@ -27,6 +27,21 @@ class ChatScene(Enum):
|
|||||||
"Dialogue with your private data through natural language.",
|
"Dialogue with your private data through natural language.",
|
||||||
["DB Select"],
|
["DB Select"],
|
||||||
)
|
)
|
||||||
|
ExcelLearning = Scene(
|
||||||
|
"excel_learning",
|
||||||
|
"Excel Learning",
|
||||||
|
"Analyze and summarize your excel files.",
|
||||||
|
[],
|
||||||
|
True,
|
||||||
|
True
|
||||||
|
)
|
||||||
|
ChatExcel = Scene(
|
||||||
|
"chat_excel",
|
||||||
|
"Chat Excel",
|
||||||
|
"Dialogue with your excel, use natural language.",
|
||||||
|
["File Select"],
|
||||||
|
)
|
||||||
|
|
||||||
ChatWithDbQA = Scene(
|
ChatWithDbQA = Scene(
|
||||||
"chat_with_db_qa",
|
"chat_with_db_qa",
|
||||||
"Chat DB",
|
"Chat DB",
|
||||||
|
51
pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
Normal file
51
pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Prompt registration for the ChatExcel (excel analyze) scene.

NOTE(review): this template text is identical to excel_learning/prompt.py — it
asks the model to *learn* the example data rather than answer an analysis
question. Presumably a copy-paste starting point for the ChatExcel prompt;
confirm before release.
"""
import json

from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle

CFG = Config()

PROMPT_SCENE_DEFINE = "You are a data analysis expert. "

_DEFAULT_TEMPLATE = """
This is an example data,please learn to understand the structure and content of this data:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
Provide some analysis options,please think step by step.

Please return your answer in JSON format, the return format is as follows:
{response}
"""

# Expected JSON shape of the model's reply; rendered into the template as {response}.
# NOTE(review): "Colunm" is misspelled but any consumer must match this exact key.
RESPONSE_FORMAT_SIMPLE = {
    "Data Analysis": "数据内容分析总结",
    "Colunm Analysis": [{"colunm name": "字段介绍,专业术语解释(请尽量简单明了)"}],
    "Analysis Program": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

# The scene returns a single JSON document, so streaming output is disabled.
PROMPT_NEED_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

prompt = PromptTemplate(
    template_scene=ChatScene.ChatExcel.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    # NOTE(review): reuses the chat_db DbChatOutputParser instead of an
    # excel-specific parser — confirm this is intended for the ChatExcel scene.
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
# Register as the default prompt for this scene.
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
|
|
39
pilot/scene/chat_data/chat_excel/excel_learning/chat.py
Normal file
39
pilot/scene/chat_data/chat_excel/excel_learning/chat.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from pilot.scene.base_message import (
|
||||||
|
HumanMessage,
|
||||||
|
ViewMessage,
|
||||||
|
)
|
||||||
|
from pilot.scene.base_chat import BaseChat
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.common.sql_database import Database
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.common.markdown_text import (
|
||||||
|
generate_htm_table,
|
||||||
|
)
|
||||||
|
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class ExcelLearning(BaseChat):
    """Chat scene that profiles an uploaded excel file (structure + analysis plan)."""

    # Scene key used to look up this chat's registered prompt template.
    chat_scene: str = ChatScene.ExcelLearning.value()

    def __init__(self, chat_session_id, file_path):
        """Create the learning chat for one session.

        Args:
            chat_session_id: conversation id this learning run belongs to.
            file_path: path of the excel file to learn, passed through as the
                scene's select_param.
        """
        # Defect fix: chat_mode was copy-pasted as ChatScene.ChatWithDbExecute,
        # which contradicts the chat_scene declared above and would route this
        # chat through the DB-execute scene's prompt.
        chat_mode = ChatScene.ExcelLearning
        super().__init__(
            chat_mode=chat_mode,
            chat_session_id=chat_session_id,
            select_param=file_path,
        )

    def generate_input_values(self):
        """Build the prompt variables for the learning template.

        TODO: "data_example" is still empty — it should carry a sample of the
        loaded excel data for the model to learn from.
        """
        input_values = {
            "data_example": "",
        }
        return input_values
|
||||||
|
|
||||||
|
|
@@ -0,0 +1,65 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, NamedTuple, List
|
||||||
|
import pandas as pd
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class ExcelResponse(NamedTuple):
    """Parsed result of the excel-learning prompt.

    NOTE(review): the field names keep the original misspellings ("desciption",
    "clounms") because they are the interface consumed by the parser/renderer;
    renaming them would break callers.
    """

    desciption: str  # free-text summary (the "Data Analysis" section)
    clounms: List  # per-column explanations (the column-analysis section)
    plans: List  # suggested analysis plans (the "Analysis Program" section)
|
||||||
|
|
||||||
|
|
||||||
|
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatExcelOutputParser(BaseOutputParser):
    """Parse the excel-learning prompt's JSON answer into an ExcelResponse and
    render it as an HTML fragment."""

    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, model_out_text):
        """Extract the three expected sections from the model's JSON output.

        Returns:
            ExcelResponse with any missing section defaulted instead of raising.
        """
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        response = json.loads(clean_str)
        # Defect fix: the originals were only assigned inside the matching
        # branches, so a missing section raised UnboundLocalError.
        desciption = ""
        clounms = []
        plans = []
        for key in sorted(response):
            stripped = key.strip()
            if stripped == "Data Analysis":
                desciption = response[key]
            # Defect fix: the prompt's RESPONSE_FORMAT uses the misspelled key
            # "Colunm Analysis" while the parser looked for "Column Analysis";
            # accept both spellings so either model output parses.
            if stripped in ("Column Analysis", "Colunm Analysis"):
                clounms = response[key]
            if stripped == "Analysis Program":
                plans = response[key]
        return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)

    def parse_view_response(self, speak, data) -> str:
        """Render the parsed analysis as an HTML fragment (title, column list,
        numbered plan list)."""
        # Defensive: parse_prompt_response returns an ExcelResponse NamedTuple,
        # which does not support string indexing; normalize to a dict.
        if isinstance(data, ExcelResponse):
            data = data._asdict()
        html_title = data["desciption"]
        html_colunms = f"<h5>数据结构</h5><ul>"
        for item in data["clounms"]:
            html_colunms = html_colunms + "<li>"
            keys = item.keys()
            for key in keys:
                html_colunms = html_colunms + f"{key}:{item[key]}"
            html_colunms = html_colunms + "</li>"
        html_colunms = html_colunms + "</ul>"

        html_plans = "<ol>"
        for item in data["plans"]:
            html_plans = html_plans + f"<li>{item}</li>"
        # Defect fix: the ordered list was never closed.
        html_plans = html_plans + "</ol>"
        # Defect fix: the final tag was an opening <div> instead of </div>.
        html = f"""
<div>
<h4>{html_title}</h4>
<div>{html_colunms}</div>
<div>{html_plans}</div>
</div>
"""
        return html
|
51
pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
Normal file
51
pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Prompt registration for the ExcelLearning scene: the model is shown a data
sample and asked to summarize structure, explain columns, and propose analyses."""
import json

from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle

CFG = Config()

PROMPT_SCENE_DEFINE = "You are a data analysis expert. "

_DEFAULT_TEMPLATE = """
This is an example data,please learn to understand the structure and content of this data:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
Provide some analysis options,please think step by step.

Please return your answer in JSON format, the return format is as follows:
{response}
"""

# Expected JSON shape of the model's reply; rendered into the template as {response}.
# NOTE(review): "Colunm" is misspelled but any consumer must match this exact key.
RESPONSE_FORMAT_SIMPLE = {
    "Data Analysis": "数据内容分析总结",
    "Colunm Analysis": [{"colunm name": "字段介绍,专业术语解释(请尽量简单明了)"}],
    "Analysis Program": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

# The scene returns a single JSON document, so streaming output is disabled.
PROMPT_NEED_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

prompt = PromptTemplate(
    template_scene=ChatScene.ExcelLearning.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    # NOTE(review): uses the chat_db DbChatOutputParser even though this
    # package defines ChatExcelOutputParser — confirm which parser is intended.
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
# Register as the default prompt for this scene.
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
|
|
84
pilot/scene/chat_data/chat_excel/excel_learning/test.py
Normal file
84
pilot/scene/chat_data/chat_excel/excel_learning/test.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
import duckdb
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
import time
|
||||||
|
from fsspec import filesystem
|
||||||
|
import spatial
|
||||||
|
|
||||||
|
|
||||||
|
# Developer scratch script: explores loading an excel file with pandas and
# querying it through an in-memory DuckDB connection. Hardcoded local paths
# (/Users/tuyang.yhj/...) mean it only runs on the author's machine.
if __name__ == "__main__":
    # connect = duckdb.connect("/Users/tuyang.yhj/Downloads/example.xlsx")
    #

    def csv_colunm_foramt(val):
        # Strip currency symbols ($ / ¥) and thousands separators so money
        # columns load as floats instead of strings.
        if str(val).find("$") >= 0:
            return float(val.replace('$', '').replace(',', ''))
        if str(val).find("¥") >= 0:
            return float(val.replace('¥', '').replace(',', ''))
        return val

    # Timestamp (ms) marking the start of the load, for the timing printout below.
    start_time = int(time.time() * 1000)

    # First read only discovers the column count; the second read applies the
    # currency converter to every column.
    df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx')
    # Read the excel file into a pandas DataFrame with per-column converters.
    df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx', converters={i: csv_colunm_foramt for i in range(df.shape[1])})

    d = df.values
    print(d.shape[0])
    for row in d:
        print(row[0])
        print(len(row))
    r = df.iterrows()

    # Timestamp (ms) marking the end of the load.
    end_time = int(time.time() * 1000)

    print(f"耗时:{(end_time-start_time)/1000}秒")

    # Connect to an in-memory DuckDB database.
    con = duckdb.connect(database=':memory:', read_only=False)

    # Register the DataFrame as a table named "example" inside DuckDB.
    con.register('example', df)

    # Query the registered table and print column names plus the first rows.
    conn = con.cursor()
    results = con.execute('SELECT * FROM example limit 5 ')
    colunms = []
    for descrip in results.description:
        colunms.append(descrip[0])
    print(colunms)
    for row in results.fetchall():
        print(row)

    # Earlier experiments with the DuckDB spatial extension / direct CSV reads,
    # kept for reference:
    # con = duckdb.connect(':memory:')

    # # Load the spatial extension.
    # con.execute('install spatial;')
    # con.execute('load spatial;')
    #
    # # Query the duckdb_internal system table to list extensions.
    # result = con.execute("SELECT * FROM duckdb_internal.functions WHERE schema='list_extensions';")
    #
    # # Iterate the result, printing extension name and version.
    # for row in result:
    #     print(row['name'], row['return_type'])
    # duckdb.read_csv('/Users/tuyang.yhj/Downloads/example_csc.csv')
    # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/yhj-zx.csv" ')
    # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/example_csc.csv" limit 20')
    # for row in result.fetchall():
    #     print(row)

    # result = con.execute("SELECT * FROM st_read('/Users/tuyang.yhj/Downloads/example.xlsx', layer='Sheet1')")
    # # Iterate the query result.
    # for row in result.fetchall():
    #     print(row)
    print("xx")
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -52,7 +52,3 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
return data_loader.get_table_view_by_conn(data, speak)
|
return data_loader.get_table_view_by_conn(data, speak)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> str:
|
|
||||||
return "sql_chat"
|
|
||||||
|
Reference in New Issue
Block a user