refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

View File

@@ -0,0 +1,87 @@
import os
import logging
from typing import Dict
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.agent.commands.command_mange import ApiCall
from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from dbgpt.util.path_utils import has_path
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.util.tracer import root_tracer, trace
CFG = Config()
logger = logging.getLogger(__name__)
class ChatExcel(BaseChat):
"""a Excel analyzer to analyze Excel Data"""
chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 2
def __init__(self, chat_param: Dict):
"""Chat Excel Module Initialization
Args:
- chat_param: Dict
- chat_session_id: (str) chat session_id
- current_user_input: (str) current user input
- model_name:(str) llm model name
- select_param:(str) file path
"""
chat_mode = ChatScene.ChatExcel
self.select_param = chat_param["select_param"]
self.model_name = chat_param["model_name"]
chat_param["chat_mode"] = ChatScene.ChatExcel
if has_path(self.select_param):
self.excel_reader = ExcelReader(self.select_param)
else:
self.excel_reader = ExcelReader(
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)
self.api_call = ApiCall(display_registry=CFG.command_disply)
super().__init__(chat_param=chat_param)
@trace()
async def generate_input_values(self) -> Dict:
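        # Prompt variables: the user's question, the registered DuckDB table
        # name, and the numbered list of supported display/chart types.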
input_values = {
"user_input": self.current_user_input,
"table_name": self.excel_reader.table_name,
"disply_type": self._generate_numbered_list(),
}
return input_values
async def prepare(self):
logger.info(f"{self.chat_mode} prepare start!")
if len(self.history_message) > 0:
return None
chat_param = {
"chat_session_id": self.chat_session_id,
"user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analyze",
"parent_mode": self.chat_mode,
"select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader,
"model_name": self.model_name,
}
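        # Run a one-shot ExcelLearning sub-chat so the model first learns the
        # file's structure and column semantics.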
learn_chat = ExcelLearning(**chat_param)
result = await learn_chat.nostream_call()
return result
def stream_plugin_call(self, text):
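        # Normalize the streamed model output before extracting the <api-call>
        # block: collapse escaped newlines and stray backslashes.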
text = (
text.replace("\\n", " ")
.replace("\n", " ")
.replace("\_", "_")
.replace("\\", " ")
)
with root_tracer.start_span(
"ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
):
return self.api_call.display_sql_llmvis(
text, self.excel_reader.get_df_by_sql_ex
)

View File

@@ -0,0 +1,41 @@
import json
import logging
from typing import NamedTuple
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt._private.config import Config
CFG = Config()
class ExcelAnalyzeResponse(NamedTuple):
sql: str
thoughts: str
display: str
logger = logging.getLogger(__name__)
class ChatExcelOutputParser(BaseOutputParser):
def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
try:
response = json.loads(clean_str)
for key in sorted(response):
if key.strip() == "sql":
sql = response[key].replace("\\", " ")
if key.strip() == "thoughts":
thoughts = response[key]
if key.strip() == "display":
display = response[key]
return ExcelAnalyzeResponse(sql, thoughts, display)
        except Exception as e:
            raise ValueError("LLM response could not be parsed!") from e
def parse_view_response(self, speak, data, prompt_response) -> str:
        # Render the tool's output data as a table view.
return data

View File

@@ -0,0 +1,73 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.out_parser import (
ChatExcelOutputParser,
)
CFG = Config()
_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """
Please use the data structure and column analysis information generated in the historical dialogue above to answer the user's questions through duckdb sql data analysis, under the following constraints.
Constraint:
1. Please fully understand the user's question and use duckdb sql for analysis. Return the analysis content in the output format required below, and output the sql in the corresponding sql parameter.
2. Please choose the best display method from those given below for data rendering, and put its type name into the name parameter of the required return format. If you cannot find the most suitable one, use 'Table' as the display method. The available data display methods are as follows: {disply_type}
3. The table name to use in SQL is: {table_name}. Please check the sql you generate and do not use column names that are not in the data structure.
4. Give priority to answering with data analysis. If the user's question does not involve data analysis, you may answer according to your own understanding.
5. Convert the sql part of the output into the format: <api-call><name>[data display method]</name><args><sql>[correct duckdb data analysis sql]</sql></args></api-call>, as described in the return format requirements.
Please think step by step and give your answer, and make sure your answer is formatted as follows:
thoughts summary to say to user.<api-call><name>[Data display method]</name><args><sql>[Correct duckdb data analysis sql]</sql></args></api-call>
User Questions:
{user_input}
"""
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """
请使用历史对话中的数据结构信息在满足下面约束条件下通过duckdb sql数据分析回答用户的问题。
约束条件:
1.请充分理解用户的问题,使用duckdb sql的方式进行分析,分析内容按下面要求的输出格式返回,sql请输出在对应的sql参数中
2.请从如下给出的展示方式中选择最优的一种进行数据渲染,将类型名称放入返回要求格式的name参数值中,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {disply_type}
3.SQL中需要使用的表名是: {table_name},请检查你生成的sql不要使用没在数据结构中的列名
4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
5.输出内容中sql部分转换为:<api-call><name>[数据显示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call> 这样的格式,参考返回格式要求
请一步一步思考,给出回答,并确保你的回答内容格式如下:
对用户说的想法摘要.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
用户问题:{user_input}
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
_PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
PROMPT_NEED_STREAM_OUT = True
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.3
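# A lower temperature (0.3) here keeps the generated SQL deterministic; the
# schema-learning prompt uses a higher value (0.8) since its output is
# free-form descriptive text.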
prompt = PromptTemplate(
template_scene=ChatScene.ChatExcel.value(),
input_variables=["user_input", "table_name", "disply_type"],
template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=ChatExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
need_historical_messages=True,
# example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -0,0 +1,62 @@
import json
from typing import Any, Dict, Optional
from dbgpt.core.interface.message import ViewMessage, AIMessage
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.json_utils import DateTimeEncoder
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import trace
class ExcelLearning(BaseChat):
chat_scene: str = ChatScene.ExcelLearning.value()
def __init__(
self,
chat_session_id,
user_input,
parent_mode: Any = None,
        select_param: Optional[str] = None,
        excel_reader: Any = None,
        model_name: Optional[str] = None,
):
chat_mode = ChatScene.ExcelLearning
""" """
self.excel_file_path = select_param
self.excel_reader = excel_reader
chat_param = {
"chat_mode": chat_mode,
"chat_session_id": chat_session_id,
"current_user_input": user_input,
"select_param": select_param,
"model_name": model_name,
}
super().__init__(chat_param=chat_param)
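        # When launched from ChatExcel, record the parent scene so this
        # learning round is stored under the parent chat's mode.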
if parent_mode:
self.current_message.chat_mode = parent_mode.value()
@trace()
async def generate_input_values(self) -> Dict:
        # Read the sample data in a thread executor; the pandas call blocks.
        colunms, datas = await blocking_func_to_async(
            self._executor, self.excel_reader.get_sample_data
        )
self.prompt_template.output_parser.update(colunms)
datas.insert(0, colunms)
input_values = {
"data_example": json.dumps(datas, cls=DateTimeEncoder),
"file_name": self.excel_reader.excel_file_name,
}
return input_values
def message_adjust(self):
        # Replace the AI message content with the rendered view message so the
        # chat history keeps the human-readable learning result.
view_message = ""
for message in self.current_message.messages:
if message.type == ViewMessage.type:
view_message = message.content
for message in self.current_message.messages:
if message.type == AIMessage.type:
message.content = view_message

View File

@@ -0,0 +1,72 @@
import json
import logging
from typing import NamedTuple, List
from dbgpt.core.interface.output_parser import BaseOutputParser
class ExcelResponse(NamedTuple):
desciption: str
clounms: List
plans: List
logger = logging.getLogger(__name__)
class LearningExcelOutputParser(BaseOutputParser):
def __init__(self, is_stream_out: bool, **kwargs):
super().__init__(is_stream_out=is_stream_out, **kwargs)
self.is_downgraded = False
def parse_prompt_response(self, model_out_text):
try:
clean_str = super().parse_prompt_response(model_out_text)
logger.info(f"parse_prompt_response:{model_out_text},{model_out_text}")
response = json.loads(clean_str)
for key in sorted(response):
if key.strip() == "DataAnalysis":
desciption = response[key]
if key.strip() == "ColumnAnalysis":
clounms = response[key]
if key.strip() == "AnalysisProgram":
plans = response[key]
return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
        except Exception as e:
            logger.error(f"parse_prompt_response failed! {str(e)}")
            # Fall back to the raw model output, with a placeholder for each
            # known column in the data schema.
            clounms = []
            for name in self.data_schema:
                clounms.append({name: "-"})
            return ExcelResponse(desciption=model_out_text, clounms=clounms, plans=None)
    def __build_colunms_html(self, clounms_data):
        html_colunms = "### **Data Structure**\n"
        for column_index, item in enumerate(clounms_data, start=1):
            for key, value in item.items():
                html_colunms += f"- **{column_index}.[{key}]** _{value}_\n"
        return html_colunms
    def __build_plans_html(self, plans_data):
        html_plans = "### **Analysis plans**\n"
        if plans_data:
            for item in plans_data:
                html_plans += f"{item} \n"
        return html_plans
def parse_view_response(self, speak, data, prompt_response) -> str:
if data and not isinstance(data, str):
            # Render the parsed analysis as markdown: summary, column
            # structure, and suggested analysis plans.
html_title = f"### **Data Summary**\n{data.desciption} "
html_colunms = self.__build_colunms_html(data.clounms)
html_plans = self.__build_plans_html(data.plans)
html = f"""{html_title}\n{html_colunms}\n{html_plans}"""
return html
else:
return speak

View File

@@ -0,0 +1,86 @@
import json
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.out_parser import (
LearningExcelOutputParser,
)
CFG = Config()
_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """
The following is part of the data of the user file {file_name}. Please learn to understand the structure and content of the data and output the parsing results as required:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of any technical terms. If it is a Date column, please summarize the date format, like: yyyy-MM-dd HH:MM:ss.
Use the column name as the attribute name and the analysis explanation as the attribute value to form a json array, and output it in the ColumnAnalysis attribute of the returned json content.
Please do not modify or translate the column names; make sure they are consistent with the given data column names.
Provide the user with some useful analysis ideas for the data from different dimensions.
Please think step by step and give your answer. Make sure to answer only in JSON format; the format is as follows:
{response}
"""
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据分析专家. "
_DEFAULT_TEMPLATE_ZH = """
下面是用户文件{file_name}的一部分数据,请学习理解该数据的结构和内容,按要求输出解析结果:
{data_example}
分析各列数据的含义和作用,并对专业术语进行简单明了的解释, 如果是时间类型请给出时间格式类似:yyyy-MM-dd HH:MM:ss.
将列名作为属性名,分析解释作为属性值,组成json数组并输出在返回json内容的ColumnAnalysis属性中.
请不要修改或者翻译列名,确保和给出数据列名一致.
针对数据从不同维度提供一些有用的分析思路给用户。
请一步一步思考,确保只以JSON格式回答,具体格式如下:
{response}
"""
_RESPONSE_FORMAT_SIMPLE_ZH = {
"DataAnalysis": "数据内容分析总结",
"ColumnAnalysis": [{"column name": "字段1介绍专业术语解释(请尽量简单明了)"}],
"AnalysisProgram": ["1.分析方案1", "2.分析方案2"],
}
_RESPONSE_FORMAT_SIMPLE_EN = {
"DataAnalysis": "Data content analysis summary",
"ColumnAnalysis": [
{
"column name": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"
}
],
"AnalysisProgram": ["1. Analysis plan ", "2. Analysis plan "],
}
RESPONSE_FORMAT_SIMPLE = (
_RESPONSE_FORMAT_SIMPLE_EN if CFG.LANGUAGE == "en" else _RESPONSE_FORMAT_SIMPLE_ZH
)
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
PROMPT_NEED_STREAM_OUT = False
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.8
prompt = PromptTemplate(
template_scene=ChatScene.ExcelLearning.value(),
input_variables=["data_example"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=LearningExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
# example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -0,0 +1,23 @@
import sqlparse


def add_quotes(sql, column_names=None):
    """Double-quote identifiers in a SQL statement so DuckDB treats them as column names."""
    column_names = column_names or []
    parsed = sqlparse.parse(sql)
    for stmt in parsed:
        for token in stmt.tokens:
            deep_quotes(token, column_names)
    return str(parsed[0])


def deep_quotes(token, column_names=None):
    column_names = column_names or []
    if hasattr(token, "tokens"):
        for token_child in token.tokens:
            deep_quotes(token_child, column_names)
    elif token.ttype == sqlparse.tokens.Name:
        # Quote names that match a known column, or every name when no
        # column list is given.
        if column_names:
            if token.value in column_names:
                token.value = f'"{token.value}"'
        else:
            token.value = f'"{token.value}"'

View File

@@ -0,0 +1,302 @@
import logging
import duckdb
import os
import sqlparse
import pandas as pd
import chardet
import numpy as np
from pyparsing import (
CaselessKeyword,
Word,
alphanums,
delimitedList,
Forward,
Optional,
Literal,
Regex,
)
from dbgpt.util.pd_utils import csv_colunm_foramt
from dbgpt.util.string_utils import is_chinese_include_number
logger = logging.getLogger(__name__)
def excel_colunm_format(old_name: str) -> str:
new_column = old_name.strip()
new_column = new_column.replace(" ", "_")
return new_column
def detect_encoding(file_path):
    # Read the file's raw bytes
with open(file_path, "rb") as f:
data = f.read()
    # Detect the file encoding with chardet
result = chardet.detect(data)
encoding = result["encoding"]
confidence = result["confidence"]
return encoding, confidence
def add_quotes_ex(sql: str, column_names):
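    # Crude fallback quoting: swap backticks for double quotes and wrap any
    # known column name that is not already quoted.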
sql = sql.replace("`", '"')
for column_name in column_names:
if sql.find(column_name) != -1 and sql.find(f'"{column_name}"') == -1:
sql = sql.replace(column_name, f'"{column_name}"')
return sql
def parse_sql(sql):
    # Define keywords and identifiers
select_stmt = Forward()
column = Regex(r"[\w一-龥]*")
table = Word(alphanums)
join_expr = Forward()
where_expr = Forward()
group_by_expr = Forward()
order_by_expr = Forward()
select_keyword = CaselessKeyword("SELECT")
from_keyword = CaselessKeyword("FROM")
join_keyword = CaselessKeyword("JOIN")
on_keyword = CaselessKeyword("ON")
where_keyword = CaselessKeyword("WHERE")
group_by_keyword = CaselessKeyword("GROUP BY")
order_by_keyword = CaselessKeyword("ORDER BY")
and_keyword = CaselessKeyword("AND")
or_keyword = CaselessKeyword("OR")
in_keyword = CaselessKeyword("IN")
not_in_keyword = CaselessKeyword("NOT IN")
    # Define the grammar rules
select_stmt <<= (
select_keyword
+ delimitedList(column)
+ from_keyword
+ delimitedList(table)
+ Optional(join_expr)
+ Optional(where_keyword + where_expr)
+ Optional(group_by_keyword + group_by_expr)
+ Optional(order_by_keyword + order_by_expr)
)
join_expr <<= join_keyword + table + on_keyword + column + Literal("=") + column
where_expr <<= (
column + Literal("=") + Word(alphanums) + Optional(and_keyword + where_expr)
| column + Literal(">") + Word(alphanums) + Optional(and_keyword + where_expr)
| column + Literal("<") + Word(alphanums) + Optional(and_keyword + where_expr)
)
group_by_expr <<= delimitedList(column)
order_by_expr <<= column + Optional(Literal("ASC") | Literal("DESC"))
    # Parse the SQL statement
parsed_result = select_stmt.parseString(sql)
return parsed_result.asList()
def add_quotes(sql, column_names=None):
    column_names = column_names or []
    sql = sql.replace("`", "")
    sql = sql.replace("'", "")
    parsed = sqlparse.parse(sql)
    for stmt in parsed:
        for token in stmt.tokens:
            deep_quotes(token, column_names)
    return str(parsed[0])


def deep_quotes(token, column_names=None):
    column_names = column_names or []
    if hasattr(token, "tokens"):
        for token_child in token.tokens:
            deep_quotes(token_child, column_names)
    elif is_chinese_include_number(token.value):
        # Chinese or mixed Chinese/number identifiers must be double-quoted
        # for DuckDB.
        new_value = token.value.replace("`", "").replace("'", "")
        token.value = f'"{new_value}"'
def get_select_clause(sql):
    parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement
select_tokens = []
is_select = False
for token in parsed.tokens:
if token.is_keyword and token.value.upper() == "SELECT":
is_select = True
elif is_select:
if token.is_keyword and token.value.upper() == "FROM":
break
select_tokens.append(token)
return "".join(str(token) for token in select_tokens)
def parse_select_fields(sql):
    parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement
    fields = []
    for token in parsed.tokens:
        # flatten() is intended to merge literals such as '2022' and '年'
        if token.ttype is sqlparse.tokens.Literal.String.Single:
            token.flatten()
        if isinstance(token, sqlparse.sql.Identifier):
            fields.append(token.get_real_name())
    # Quote the extracted field names (needed for Chinese column names)
    fields = [f'"{field}"' for field in fields]
    return fields
def add_quotes_to_chinese_columns(sql, column_names=None):
    """Double-quote Chinese column identifiers throughout a SQL statement."""
    column_names = column_names or []
    parsed = sqlparse.parse(sql)
    for stmt in parsed:
        process_statement(stmt, column_names)
    return str(parsed[0])
def process_statement(statement, column_names=None):
    column_names = column_names or []
    if isinstance(statement, sqlparse.sql.IdentifierList):
        for identifier in statement.get_identifiers():
            process_identifier(identifier, column_names)
    elif isinstance(statement, sqlparse.sql.Identifier):
        process_identifier(statement, column_names)
    elif isinstance(statement, sqlparse.sql.TokenList):
        for item in statement.tokens:
            process_statement(item, column_names)
def process_identifier(identifier, column_names=None):
    column_names = column_names or []
if hasattr(identifier, "tokens") and identifier.value in column_names:
if is_chinese(identifier.value):
new_value = get_new_value(identifier.value)
identifier.value = new_value
identifier.normalized = new_value
identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
else:
if hasattr(identifier, "tokens"):
for token in identifier.tokens:
if isinstance(token, sqlparse.sql.Function):
process_function(token)
elif token.ttype in sqlparse.tokens.Name:
new_value = get_new_value(token.value)
token.value = new_value
token.normalized = new_value
elif token.value in column_names:
new_value = get_new_value(token.value)
token.value = new_value
token.normalized = new_value
token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
def get_new_value(value):
return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """
def process_function(function):
    function_params = list(function.get_parameters())
    for i in range(len(function_params)):
        param = function_params[i]
        # If the parameter is an identifier (a column name), quote its value
        if isinstance(param, sqlparse.sql.Identifier):
            new_value = get_new_value(param.value)
            function_params[i].tokens = [
                sqlparse.sql.Token(sqlparse.tokens.Name, new_value)
            ]
def is_chinese(text):
    for char in text:
        if "\u4e00" <= char <= "\u9fff":
            return True
    return False
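# Rough usage sketch (hypothetical data with a Chinese column 销量):
#   add_quotes_to_chinese_columns('SELECT 销量 FROM excel_data', ["销量"])
# returns the statement with 销量 double-quoted so DuckDB can resolve it.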
class ExcelReader:
def __init__(self, file_path):
file_name = os.path.basename(file_path)
self.file_name_without_extension = os.path.splitext(file_name)[0]
encoding, confidence = detect_encoding(file_path)
logger.info(f"Detected Encoding: {encoding} (Confidence: {confidence})")
self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1]
# read excel file
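        # The file is read twice: the first pass only discovers the column
        # count so value converters can be applied per column on the second.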
if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
df_tmp = pd.read_excel(file_path, index_col=False)
self.df = pd.read_excel(
file_path,
index_col=False,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
elif file_path.endswith(".csv"):
df_tmp = pd.read_csv(file_path, index_col=False, encoding=encoding)
self.df = pd.read_csv(
file_path,
index_col=False,
encoding=encoding,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else:
raise ValueError("Unsupported file format.")
self.df.replace("", np.nan, inplace=True)
self.columns_map = {}
for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)})
try:
if not pd.api.types.is_datetime64_ns_dtype(self.df[column_name]):
self.df[column_name] = pd.to_numeric(self.df[column_name])
self.df[column_name] = self.df[column_name].fillna(0)
            except Exception:
                logger.warning(f"Cannot convert column to numeric: {column_name}")
self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_"))
# connect DuckDB
self.db = duckdb.connect(database=":memory:", read_only=False)
self.table_name = "excel_data"
        # Register the DataFrame as a DuckDB virtual table (no data is copied)
self.db.register(self.table_name, self.df)
        # Fetch and print the table schema information
result = self.db.execute(f"DESCRIBE {self.table_name}")
columns = result.fetchall()
for column in columns:
print(column)
def run(self, sql):
try:
if f'"{self.table_name}"' in sql:
sql = sql.replace(f'"{self.table_name}"', self.table_name)
sql = add_quotes_to_chinese_columns(sql)
print(f"excute sql:{sql}")
results = self.db.execute(sql)
colunms = []
for descrip in results.description:
colunms.append(descrip[0])
return colunms, results.fetchall()
        except Exception as e:
            logger.error(f"excel sql run error! {str(e)}")
            raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
def get_df_by_sql_ex(self, sql):
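        # Execute the SQL and wrap the result in a DataFrame; ChatExcel hands
        # this method to ApiCall.display_sql_llmvis as the data callback.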
colunms, values = self.run(sql)
return pd.DataFrame(values, columns=colunms)
def get_sample_data(self):
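        # Header plus the first five rows; used as the learning prompt's data example.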
return self.run(f"SELECT * FROM {self.table_name} LIMIT 5;")