mirror of https://github.com/csunny/DB-GPT.git
feat(editor): ChatExcel
ChatExcel develop part 1
pilot/commands/built_in/show_chart_gen.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from pilot.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
import base64
import io
import matplotlib
import seaborn as sns

matplotlib.use("Agg")
import matplotlib.pyplot as plt

from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger

CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")


@command("response_line_chart", "Use line chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_line_chart(speak: str, sql: str, db_name: str) -> str:
    logger.info(f"response_line_chart:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    columns = df.columns.tolist()

    if df.size <= 0:
        raise ValueError("No Data!" + sql)
    plt.rcParams["font.family"] = ["sans-serif"]
    rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
    sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
    plt.subplots(figsize=(8, 5), dpi=100)
    # draw the line chart: first result column on the x axis, second on the y axis
    sns.lineplot(data=df, x=columns[0], y=columns[1])
    plt.title("")

    buf = io.BytesIO()
    plt.savefig(buf, format="png", dpi=100)
    buf.seek(0)
    data = base64.b64encode(buf.getvalue()).decode("ascii")

    html_img = f"""<h5>{speak}</h5><img style='max-width: 120%; max-height: 80%;' src="data:image/png;base64,{data}" />"""
    return html_img


@command("response_bar_chart", "Use bar chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_bar_chart(speak: str, sql: str, db_name: str) -> str:
    """Render SQL result data as a bar chart (not implemented in this commit)."""
    pass


@command("response_pie_chart", "Use pie chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_pie_chart(speak: str, sql: str, db_name: str) -> str:
    """Render SQL result data as a pie chart (not implemented in this commit)."""
    pass
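The bar and pie chart commands above are still stubs in this commit. A minimal sketch of how response_bar_chart could be filled in, mirroring the flow of response_line_chart; it assumes the imports, CFG and logger already defined at the top of show_chart_gen.py, and the seaborn/base64 steps are reused from the line-chart command rather than taken from any later commit:

# Sketch only: reuses the module-level imports, CFG and logger from show_chart_gen.py.
@command("response_bar_chart", "Use bar chart to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_bar_chart(speak: str, sql: str, db_name: str) -> str:
    logger.info(f"response_bar_chart:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    if df.size <= 0:
        raise ValueError("No Data!" + sql)
    columns = df.columns.tolist()

    rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
    sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
    plt.subplots(figsize=(8, 5), dpi=100)
    # first result column as the category axis, second column as the bar height
    sns.barplot(data=df, x=columns[0], y=columns[1])

    buf = io.BytesIO()
    plt.savefig(buf, format="png", dpi=100)
    buf.seek(0)
    data = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"""<h5>{speak}</h5><img style='max-width: 120%; max-height: 80%;' src="data:image/png;base64,{data}" />"""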
pilot/commands/built_in/show_table_gen.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import pandas as pd

from pilot.commands.command_mange import command
from pilot.configs.config import Config

from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger

CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")


@command("response_table", "Use table to display SQL data", '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_table(speak: str, sql: str, db_name: str) -> str:
    logger.info(f"response_table:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    html_table = df.to_html(index=False, escape=False, sparsify=False)
    table_str = "".join(html_table.split())
    html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
    view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
    return view_text
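A quick standalone illustration of what response_table builds, using an invented DataFrame so no database connection is needed; only the to_html and whitespace-join steps from the command above are exercised:

# Hypothetical sample data; mirrors the rendering steps inside response_table.
import pandas as pd

df = pd.DataFrame({"region": ["north", "south"], "amount": [100, 250]})
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())  # collapse all whitespace, as the command does
print(f'<div class="w-full overflow-auto">{table_str}</div>')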
pilot/commands/built_in/show_text_gen.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import pandas as pd

from pilot.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger

CFG = Config()
logger = build_logger("show_text_gen", LOGDIR + "show_text_gen.log")


@command("response_data_text", "Use text to display SQL data",
         '"speak": "<speak>", "sql":"<sql>","db_name":"<db_name>"')
def response_data_text(speak: str, sql: str, db_name: str) -> str:
    logger.info(f"response_data_text:{speak},{sql},{db_name}")
    df = pd.read_sql(sql, CFG.LOCAL_DB_MANAGE.get_connect(db_name))
    data = df.values

    row_size = data.shape[0]
    value_str, text_info = "", ""
    if row_size > 1:
        # multiple rows: fall back to an HTML table view
        html_table = df.to_html(index=False, escape=False, sparsify=False)
        table_str = "".join(html_table.split())
        html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
        text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
    elif row_size == 1:
        # single row: inline every value after the spoken summary
        row = data[0]
        for value in row:
            value_str = value_str + f", ** {value} **"
        text_info = f"{speak}: {value_str}"
    else:
        # "没有找到可用的数据" = "no usable data found"
        text_info = f"##### {speak}: _没有找到可用的数据_"
    return text_info
@@ -39,7 +39,7 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")


@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
-    logger.info("get_editor_tables:{},{},{},{}", db_name, page_index, page_size, search_str)
+    logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
    db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
    tables = db_conn.get_table_names()
    db_node: DataNode = DataNode(title=db_name, key=db_name, type="db")
@@ -57,7 +57,7 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc


@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds])
async def get_editor_sql_rounds(con_uid: str):
-    logger.info("get_editor_sql_rounds:{}", con_uid)
+    logger.info(f"get_editor_sql_rounds:{con_uid}")
    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
@@ -76,7 +76,7 @@ async def get_editor_sql_rounds(con_uid: str):


@router.get("/v1/editor/sql", response_model=Result[dict])
async def get_editor_sql(con_uid: str, round: int):
-    logger.info("get_editor_sql:{},{}", con_uid, round)
+    logger.info(f"get_editor_sql:{con_uid},{round}")
    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
@@ -90,13 +90,14 @@ async def get_editor_sql(con_uid: str, round: int):


@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()):
-    logger.info("editor_sql_run:{}", run_param)
+    logger.info(f"editor_sql_run:{run_param}")
    db_name = run_param['db_name']
    sql = run_param['sql']
    if not db_name and not sql:
        return Result.faild("SQL run param error!")
    conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)

    try:
        start_time = time.time() * 1000
        colunms, sql_result = conn.query_ex(sql)
        # measure execution time
@@ -107,9 +108,15 @@ async def editor_sql_run(run_param: dict = Body()):
            values=sql_result
        )
        return Result.succ(sql_run_data)
+    except Exception as e:
+        return Result.succ(SqlRunData(result_info=str(e),
+                                      run_cost=0,
+                                      colunms=[],
+                                      values=[]
+                                      ))


-@router.post("/v1/sql/editor/submit", response_model=Result)
+@router.post("/v1/sql/editor/submit")
async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
    logger.info(f"sql_editor_submit:{sql_edit_context.__dict__}")
    history_mem = DuckdbHistoryMemory(sql_edit_context.conv_uid)
@@ -136,7 +143,7 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):


@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
async def get_editor_chart_list(con_uid: str):
-    logger.info("get_editor_sql_rounds:{}", con_uid)
+    logger.info(f"get_editor_chart_list:{con_uid}")
    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
@@ -152,8 +159,7 @@ async def get_editor_chart_list(con_uid: str):


@router.get("/v1/editor/chart/info", response_model=Result[ChartDetail])
async def get_editor_chart_info(con_uid: str, chart_title: str):
-    logger.info(f"get_editor_sql_rounds:{con_uid}")
-    logger.info("get_editor_sql_rounds:{}", con_uid)
+    logger.info(f"get_editor_chart_info:{con_uid},{chart_title}")
    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
@@ -184,13 +190,18 @@ async def get_editor_chart_info(con_uid: str, chart_title: str):


@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
-async def editor_chart_run(db_name: str, sql: str):
-    logger.info(f"editor_chart_run:{db_name},{sql}")
+async def editor_chart_run(run_param: dict = Body()):
+    logger.info(f"editor_chart_run:{run_param}")
+    db_name = run_param['db_name']
+    sql = run_param['sql']
+    if not db_name and not sql:
+        return Result.faild("SQL run param error!")
    dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
    db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)

    field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, sql)

    try:
        start_time = time.time() * 1000
        colunms, sql_result = db_conn.query_ex(sql)
        # measure execution time
@@ -201,7 +212,12 @@ async def editor_chart_run(db_name: str, sql: str):
            values=sql_result
        )
        return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values))

+    except Exception as e:
+        return Result.succ(SqlRunData(result_info=str(e),
+                                      run_cost=0,
+                                      colunms=[],
+                                      values=[]
+                                      ))

@router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
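For reference, a minimal client-side sketch of the reworked run endpoints, which now read db_name and sql from a JSON body instead of query parameters. The endpoint paths and body keys come from the routes above; the base URL, the example database/SQL and the requests dependency are assumptions:

# Hypothetical client calls against a running DB-GPT web server (base URL assumed).
import requests

BASE_URL = "http://localhost:5000"  # assumption: adjust to your deployment
payload = {"db_name": "example_db", "sql": "SELECT * FROM users LIMIT 5"}

# body matches the run_param dict read by editor_sql_run / editor_chart_run
sql_resp = requests.post(f"{BASE_URL}/v1/editor/sql/run", json=payload)
print(sql_resp.json())    # Result[SqlRunData]

chart_resp = requests.post(f"{BASE_URL}/v1/editor/chart/run", json=payload)
print(chart_resp.json())  # Result[ChartRunData]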
@@ -57,8 +57,8 @@ class ChatSqlEditContext(BaseModel):

    old_sql: str
    old_speak: str
-    gmt_create: int
+    gmt_create: int = 0

    new_sql: str
-    new_speak: str
-    new_view_info: str
+    new_speak: str = ""
+    new_view_info: str = ""
@@ -226,13 +226,13 @@ class BaseOutputParser(ABC):
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

-    @property
-    def _type(self) -> str:
-        """Return the type key."""
-        raise NotImplementedError(
-            f"_type property is not implemented in class {self.__class__.__name__}."
-            " This is required for serialization."
-        )
+    # @property
+    # def _type(self) -> str:
+    #     """Return the type key."""
+    #     raise NotImplementedError(
+    #         f"_type property is not implemented in class {self.__class__.__name__}."
+    #         " This is required for serialization."
+    #     )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser."""
@@ -27,6 +27,21 @@ class ChatScene(Enum):
        "Dialogue with your private data through natural language.",
        ["DB Select"],
    )
+    ExcelLearning = Scene(
+        "excel_learning",
+        "Excel Learning",
+        "Analyze and summarize your excel files.",
+        [],
+        True,
+        True
+    )
+    ChatExcel = Scene(
+        "chat_excel",
+        "Chat Excel",
+        "Dialogue with your excel, use natural language.",
+        ["File Select"],
+    )

    ChatWithDbQA = Scene(
        "chat_with_db_qa",
        "Chat DB",
pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle

CFG = Config()

PROMPT_SCENE_DEFINE = "You are a data analysis expert. "

_DEFAULT_TEMPLATE = """
This is example data, please learn to understand the structure and content of this data:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
Provide some analysis options, please think step by step.

Please return your answer in JSON format, the return format is as follows:
{response}
"""

RESPONSE_FORMAT_SIMPLE = {
    "Data Analysis": "数据内容分析总结",
    "Colunm Analysis": [{"colunm name": "字段介绍,专业术语解释(请尽量简单明了)"}],
    "Analysis Program": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

prompt = PromptTemplate(
    template_scene=ChatScene.ChatExcel.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)
pilot/scene/chat_data/chat_excel/excel_learning/chat.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import json

from pilot.scene.base_message import (
    HumanMessage,
    ViewMessage,
)
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
    generate_htm_table,
)
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt

CFG = Config()


class ExcelLearning(BaseChat):
    chat_scene: str = ChatScene.ExcelLearning.value()

    def __init__(self, chat_session_id, file_path):
        # learn the uploaded excel file in its own chat scene
        chat_mode = ChatScene.ExcelLearning
        super().__init__(
            chat_mode=chat_mode,
            chat_session_id=chat_session_id,
            select_param=file_path,
        )

    def generate_input_values(self):
        input_values = {
            "data_example": "",
        }
        return input_values
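generate_input_values currently feeds an empty data_example into the prompt. A possible way to populate it from the uploaded file, sketched here as a drop-in method body; the pandas import, the five-row sample size, the to_string rendering, and the reliance on select_param holding the excel path are assumptions, not part of this commit:

    # Sketch: assumes "import pandas as pd" at module level and that
    # self.select_param is the excel file path handed to __init__.
    def generate_input_values(self):
        df = pd.read_excel(self.select_param)
        # give the model a small, readable sample instead of the whole sheet
        data_example = df.head(5).to_string(index=False)
        return {"data_example": data_example}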
@@ -0,0 +1,65 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config

CFG = Config()


class ExcelResponse(NamedTuple):
    desciption: str
    clounms: List
    plans: List


logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")


class ChatExcelOutputParser(BaseOutputParser):
    def __init__(self, sep: str, is_stream_out: bool):
        super().__init__(sep=sep, is_stream_out=is_stream_out)

    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        response = json.loads(clean_str)
        # default to empty values so a missing key cannot leave these unbound
        desciption, clounms, plans = "", [], []
        for key in sorted(response):
            if key.strip() == "Data Analysis":
                desciption = response[key]
            if key.strip() == "Column Analysis":
                clounms = response[key]
            if key.strip() == "Analysis Program":
                plans = response[key]
        return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)

    def parse_view_response(self, speak, data) -> str:
        ### tool out data to table view
        html_title = data["desciption"]
        # "数据结构" = "data structure"
        html_colunms = f"<h5>数据结构</h5><ul>"
        for item in data["clounms"]:
            html_colunms = html_colunms + "<li>"
            keys = item.keys()
            for key in keys:
                html_colunms = html_colunms + f"{key}:{item[key]}"
            html_colunms = html_colunms + "</li>"
        html_colunms = html_colunms + "</ul>"

        html_plans = "<ol>"
        for item in data["plans"]:
            html_plans = html_plans + f"<li>{item}</li>"
        html_plans = html_plans + "</ol>"
        html = f"""
        <div>
            <h4>{html_title}</h4>
            <div>{html_colunms}</div>
            <div>{html_plans}</div>
        </div>
        """
        return html
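For illustration, a hedged round trip through the parser's view rendering; the sample ExcelResponse values are invented, and the sep argument is a placeholder since the real separator comes from the prompt configuration:

# Hypothetical parsed result; field contents are invented for illustration.
resp = ExcelResponse(
    desciption="Sales grow steadily quarter over quarter.",
    clounms=[{"region": "sales region name"}, {"amount": "order amount in USD"}],
    plans=["1. Bar chart of amount by region", "2. Line chart of amount by month"],
)
parser = ChatExcelOutputParser(sep="\n", is_stream_out=False)  # sep value is an assumption
# parse_view_response only needs a mapping with the same field names
print(parser.parse_view_response("Here is my analysis", resp._asdict()))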
pilot/scene/chat_data/chat_excel/excel_learning/prompt.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle

CFG = Config()

PROMPT_SCENE_DEFINE = "You are a data analysis expert. "

_DEFAULT_TEMPLATE = """
This is example data, please learn to understand the structure and content of this data:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
Provide some analysis options, please think step by step.

Please return your answer in JSON format, the return format is as follows:
{response}
"""

RESPONSE_FORMAT_SIMPLE = {
    "Data Analysis": "数据内容分析总结",
    "Colunm Analysis": [{"colunm name": "字段介绍,专业术语解释(请尽量简单明了)"}],
    "Analysis Program": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

prompt = PromptTemplate(
    template_scene=ChatScene.ExcelLearning.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)
pilot/scene/chat_data/chat_excel/excel_learning/test.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import os
import duckdb
import pandas as pd

import time
from fsspec import filesystem
import spatial


if __name__ == "__main__":
    # connect = duckdb.connect("/Users/tuyang.yhj/Downloads/example.xlsx")
    #

    def csv_colunm_foramt(val):
        # normalize currency-formatted cells to plain floats
        if str(val).find("$") >= 0:
            return float(val.replace('$', '').replace(',', ''))
        if str(val).find("¥") >= 0:
            return float(val.replace('¥', '').replace(',', ''))
        return val

    # timestamp marking the start of the run
    start_time = int(time.time() * 1000)

    df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx')
    # re-read the Excel file into a pandas DataFrame, applying the converter to every column
    df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx', converters={i: csv_colunm_foramt for i in range(df.shape[1])})

    d = df.values
    print(d.shape[0])
    for row in d:
        print(row[0])
        print(len(row))
    r = df.iterrows()

    # timestamp marking the end of the run
    end_time = int(time.time() * 1000)

    print(f"Elapsed: {(end_time-start_time)/1000}s")

    # connect to an in-memory DuckDB database
    con = duckdb.connect(database=':memory:', read_only=False)

    # register the DataFrame as a table in DuckDB
    con.register('example', df)

    # query the table back out of DuckDB
    conn = con.cursor()
    results = con.execute('SELECT * FROM example limit 5 ')
    colunms = []
    for descrip in results.description:
        colunms.append(descrip[0])
    print(colunms)
    for row in results.fetchall():
        print(row)


    # connect to an in-memory DuckDB database
    # con = duckdb.connect(':memory:')

    # # load the spatial extension
    # con.execute('install spatial;')
    # con.execute('load spatial;')
    #
    # # query the duckdb_internal system table to list extensions
    # result = con.execute("SELECT * FROM duckdb_internal.functions WHERE schema='list_extensions';")
    #
    # # iterate over the result, printing extension name and return type
    # for row in result:
    #     print(row['name'], row['return_type'])
    # duckdb.read_csv('/Users/tuyang.yhj/Downloads/example_csc.csv')
    # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/yhj-zx.csv" ')
    # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/example_csc.csv" limit 20')
    # for row in result.fetchall():
    #     print(row)

    # result = con.execute("SELECT * FROM st_read('/Users/tuyang.yhj/Downloads/example.xlsx', layer='Sheet1')")
    # # iterate over the result
    # for row in result.fetchall():
    #     print(row)
    print("xx")
@@ -52,7 +52,3 @@ class DbChatOutputParser(BaseOutputParser):
        return data_loader.get_table_view_by_conn(data, speak)


-
-    @property
-    def _type(self) -> str:
-        return "sql_chat"