mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-13 13:10:29 +00:00)
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
0 dbgpt/app/scene/chat_data/chat_excel/__init__.py Normal file
87 dbgpt/app/scene/chat_data/chat_excel/excel_analyze/chat.py Normal file
@@ -0,0 +1,87 @@
import os
import logging

from typing import Dict
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.agent.commands.command_mange import ApiCall
from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from dbgpt.util.path_utils import has_path
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.util.tracer import root_tracer, trace

CFG = Config()

logger = logging.getLogger(__name__)


class ChatExcel(BaseChat):
    """An Excel analyzer for analyzing Excel data."""

    chat_scene: str = ChatScene.ChatExcel.value()
    chat_retention_rounds = 2

    def __init__(self, chat_param: Dict):
        """Chat Excel Module Initialization

        Args:
            - chat_param: Dict
                - chat_session_id: (str) chat session_id
                - current_user_input: (str) current user input
                - model_name: (str) llm model name
                - select_param: (str) file path
        """
        chat_mode = ChatScene.ChatExcel

        self.select_param = chat_param["select_param"]
        self.model_name = chat_param["model_name"]
        chat_param["chat_mode"] = ChatScene.ChatExcel
        if has_path(self.select_param):
            self.excel_reader = ExcelReader(self.select_param)
        else:
            self.excel_reader = ExcelReader(
                os.path.join(
                    KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
                )
            )
        self.api_call = ApiCall(display_registry=CFG.command_disply)
        super().__init__(chat_param=chat_param)

    @trace()
    async def generate_input_values(self) -> Dict:
        input_values = {
            "user_input": self.current_user_input,
            "table_name": self.excel_reader.table_name,
            "disply_type": self._generate_numbered_list(),
        }
        return input_values

    async def prepare(self):
        logger.info(f"{self.chat_mode} prepare start!")
        if len(self.history_message) > 0:
            return None
        chat_param = {
            "chat_session_id": self.chat_session_id,
            "user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analyze!",
            "parent_mode": self.chat_mode,
            "select_param": self.excel_reader.excel_file_name,
            "excel_reader": self.excel_reader,
            "model_name": self.model_name,
        }
        learn_chat = ExcelLearning(**chat_param)
        result = await learn_chat.nostream_call()
        return result

    def stream_plugin_call(self, text):
        # Normalize escape artifacts in the streamed model output before parsing.
        text = (
            text.replace("\\n", " ")
            .replace("\n", " ")
            .replace("\\_", "_")
            .replace("\\", " ")
        )
        with root_tracer.start_span(
            "ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
        ):
            return self.api_call.display_sql_llmvis(
                text, self.excel_reader.get_df_by_sql_ex
            )
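For orientation, a minimal usage sketch (not part of the commit) showing how ChatExcel might be constructed, assuming a configured DB-GPT application environment; the session id, question, model name, and file path are hypothetical placeholders following the chat_param docstring above.

from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel

chat = ChatExcel(
    chat_param={
        "chat_session_id": "session-42",  # hypothetical session id
        "current_user_input": "Top 5 products by sales?",  # hypothetical question
        "model_name": "vicuna-13b-v1.5",  # hypothetical model name
        "select_param": "/tmp/sales.xlsx",  # hypothetical uploaded file path
    }
)
# In an async context, `await chat.prepare()` would run ExcelLearning once
# for a new session, as prepare() above shows.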
41 dbgpt/app/scene/chat_data/chat_excel/excel_analyze/out_parser.py Normal file
@@ -0,0 +1,41 @@
import json
import logging
from typing import NamedTuple
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt._private.config import Config

CFG = Config()


class ExcelAnalyzeResponse(NamedTuple):
    sql: str
    thoughts: str
    display: str


logger = logging.getLogger(__name__)


class ChatExcelOutputParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        try:
            response = json.loads(clean_str)
            for key in sorted(response):
                if key.strip() == "sql":
                    sql = response[key].replace("\\", " ")
                if key.strip() == "thoughts":
                    thoughts = response[key]
                if key.strip() == "display":
                    display = response[key]
            return ExcelAnalyzeResponse(sql, thoughts, display)
        except Exception as e:
            raise ValueError(f"LLM response can't be parsed! {str(e)}")

    def parse_view_response(self, speak, data, prompt_response) -> str:
        # Tool output data rendered as a table view.
        return data
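As a sanity check of the parsing contract, an illustrative sketch (not from the commit) assuming the cleaned model output is a JSON object with sql, thoughts, and display keys; the field handling mirrors parse_prompt_response above, and the values are made up.

import json

raw = (
    '{"thoughts": "Aggregate sales by month.",'
    ' "sql": "SELECT month, SUM(sales) FROM excel_data GROUP BY month",'
    ' "display": "Table"}'
)
response = json.loads(raw)
parsed = ExcelAnalyzeResponse(
    sql=response["sql"].replace("\\", " "),
    thoughts=response["thoughts"],
    display=response["display"],
)
print(parsed.display)  # Table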
73 dbgpt/app/scene/chat_data/chat_excel/excel_analyze/prompt.py Normal file
@@ -0,0 +1,73 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.out_parser import (
    ChatExcelOutputParser,
)

CFG = Config()

_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "

_DEFAULT_TEMPLATE_EN = """
Please use the data structure and column analysis information generated in the historical dialogue above to answer the user's questions through duckdb sql data analysis, under the following constraints.

Constraints:
1. Please fully understand the user's question and analyze it with duckdb sql. Return the analysis content in the output format required below, and output the sql in the corresponding sql parameter.
2. Please choose the best of the display methods given below for data rendering, and put the type name into the name parameter value of the required return format. If you cannot find the most suitable one, use 'Table' as the display method. The available data display methods are as follows: {disply_type}
3. The table name to use in the SQL is: {table_name}. Please check the sql you generate and do not use column names that are not in the data structure.
4. Give priority to answering with data analysis. If the user's question does not involve data analysis, you can answer according to your understanding.
5. The sql part of the output content is converted to: <api-call><name>[data display method]</name><args><sql>[correct duckdb data analysis sql]</sql></args></api-call>. For this format, please refer to the return format requirements.

Please think step by step and give your answer, and make sure your answer is formatted as follows:
thoughts summary to say to user.<api-call><name>[data display method]</name><args><sql>[correct duckdb data analysis sql]</sql></args></api-call>

User Questions:
{user_input}
"""

_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """
请使用历史对话中的数据结构信息,在满足下面约束条件下通过duckdb sql数据分析回答用户的问题。
约束条件:
1.请充分理解用户的问题,使用duckdb sql的方式进行分析, 分析内容按下面要求的输出格式返回,sql请输出在对应的sql参数中
2.请从如下给出的展示方式中选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值中,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {disply_type}
3.SQL中需要使用的表名是: {table_name},请检查你生成的sql,不要使用没在数据结构中的列名
4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
5.输出内容中sql部分转换为:<api-call><name>[数据显示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call> 这样的格式,参考返回格式要求

请一步一步思考,给出回答,并确保你的回答内容格式如下:
对用户说的想法摘要.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>

用户问题:{user_input}
"""


_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

_PROMPT_SCENE_DEFINE = (
    _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)


PROMPT_NEED_STREAM_OUT = True

# Temperature is a hyperparameter that controls the randomness of language model
# output. A high temperature produces more unpredictable and creative results,
# while a low temperature produces more common and conservative output. For
# example, with the temperature set to 0.5 the model usually generates text that
# is more predictable and less creative than with the temperature set to 1.0.
PROMPT_TEMPERATURE = 0.3

prompt = PromptTemplate(
    template_scene=ChatScene.ChatExcel.value(),
    input_variables=["user_input", "table_name", "disply_type"],
    template_define=_PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=ChatExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
    need_historical_messages=True,
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)
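To make the expected answer shape concrete, a hypothetical well-formed model reply under this template (the SQL, summary, and display name are illustrative only):

example_reply = (
    "The top five products by total sales are listed below."
    "<api-call><name>Table</name>"
    "<args><sql>SELECT product, SUM(sales) AS total FROM excel_data "
    "GROUP BY product ORDER BY total DESC LIMIT 5</sql></args></api-call>"
)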
62 dbgpt/app/scene/chat_data/chat_excel/excel_learning/chat.py Normal file
@@ -0,0 +1,62 @@
import json
from typing import Any, Dict

from dbgpt.core.interface.message import ViewMessage, AIMessage
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.json_utils import DateTimeEncoder
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import trace


class ExcelLearning(BaseChat):
    chat_scene: str = ChatScene.ExcelLearning.value()

    def __init__(
        self,
        chat_session_id,
        user_input,
        parent_mode: Any = None,
        select_param: str = None,
        excel_reader: Any = None,
        model_name: str = None,
    ):
        chat_mode = ChatScene.ExcelLearning
        self.excel_file_path = select_param
        self.excel_reader = excel_reader
        chat_param = {
            "chat_mode": chat_mode,
            "chat_session_id": chat_session_id,
            "current_user_input": user_input,
            "select_param": select_param,
            "model_name": model_name,
        }
        super().__init__(chat_param=chat_param)
        if parent_mode:
            self.current_message.chat_mode = parent_mode.value()

    @trace()
    async def generate_input_values(self) -> Dict:
        # Read sample data in a thread executor to avoid blocking the event loop.
        colunms, datas = await blocking_func_to_async(
            self._executor, self.excel_reader.get_sample_data
        )
        self.prompt_template.output_parser.update(colunms)
        datas.insert(0, colunms)

        input_values = {
            "data_example": json.dumps(datas, cls=DateTimeEncoder),
            "file_name": self.excel_reader.excel_file_name,
        }
        return input_values

    def message_adjust(self):
        # Adjust the learning result in messages: replace the AI message
        # content with the rendered view message.
        view_message = ""
        for message in self.current_message.messages:
            if message.type == ViewMessage.type:
                view_message = message.content

        for message in self.current_message.messages:
            if message.type == AIMessage.type:
                message.content = view_message
72 dbgpt/app/scene/chat_data/chat_excel/excel_learning/out_parser.py Normal file
@@ -0,0 +1,72 @@
import json
import logging
from typing import NamedTuple, List
from dbgpt.core.interface.output_parser import BaseOutputParser


class ExcelResponse(NamedTuple):
    desciption: str
    clounms: List
    plans: List


logger = logging.getLogger(__name__)


class LearningExcelOutputParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)
        self.is_downgraded = False

    def parse_prompt_response(self, model_out_text):
        try:
            clean_str = super().parse_prompt_response(model_out_text)
            logger.info(f"parse_prompt_response: {model_out_text}, {clean_str}")
            response = json.loads(clean_str)
            for key in sorted(response):
                if key.strip() == "DataAnalysis":
                    desciption = response[key]
                if key.strip() == "ColumnAnalysis":
                    clounms = response[key]
                if key.strip() == "AnalysisProgram":
                    plans = response[key]
            return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
        except Exception as e:
            logger.error(f"parse_prompt_response failed! {str(e)}")
            # Downgrade: fall back to the raw model output and a placeholder schema.
            clounms = []
            for name in self.data_schema:
                clounms.append({name: "-"})
            return ExcelResponse(desciption=model_out_text, clounms=clounms, plans=None)

    def __build_colunms_html(self, clounms_data):
        html_colunms = "### **Data Structure**\n"
        column_index = 0
        for item in clounms_data:
            column_index += 1
            keys = item.keys()
            for key in keys:
                html_colunms = (
                    html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
                )
        return html_colunms

    def __build_plans_html(self, plans_data):
        html_plans = "### **Analysis plans**\n"
        if plans_data:
            for item in plans_data:
                html_plans = html_plans + f"{item} \n"
        return html_plans

    def parse_view_response(self, speak, data, prompt_response) -> str:
        if data and not isinstance(data, str):
            # Render the parsed learning result as a markdown view.
            html_title = f"### **Data Summary**\n{data.desciption} "
            html_colunms = self.__build_colunms_html(data.clounms)
            html_plans = self.__build_plans_html(data.plans)

            html = f"""{html_title}\n{html_colunms}\n{html_plans}"""
            return html
        else:
            return speak
86 dbgpt/app/scene/chat_data/chat_excel/excel_learning/prompt.py Normal file
@@ -0,0 +1,86 @@
import json
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.out_parser import (
    LearningExcelOutputParser,
)

CFG = Config()

_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "

_DEFAULT_TEMPLATE_EN = """
The following is part of the data of the user file {file_name}. Please learn to understand the structure and content of the data and output the parsing results as required:
{data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of any technical terms. If it is a date column, please summarize the date format, like: yyyy-MM-dd HH:MM:ss.
Use the column name as the attribute name and the analysis explanation as the attribute value to form a json array, and output it in the ColumnAnalysis attribute of the returned json content.
Please do not modify or translate the column names; make sure they are consistent with the given data column names.
Provide the user with some useful analysis ideas on the data from different dimensions.

Please think step by step and give your answer. Make sure to answer only in JSON format; the format is as follows:
{response}
"""

_PROMPT_SCENE_DEFINE_ZH = "你是一个数据分析专家. "

_DEFAULT_TEMPLATE_ZH = """
下面是用户文件{file_name}的一部分数据,请学习理解该数据的结构和内容,按要求输出解析结果:
{data_example}
分析各列数据的含义和作用,并对专业术语进行简单明了的解释, 如果是时间类型请给出时间格式类似:yyyy-MM-dd HH:MM:ss.
将列名作为属性名,分析解释作为属性值,组成json数组,并输出在返回json内容的ColumnAnalysis属性中.
请不要修改或者翻译列名,确保和给出数据列名一致.
针对数据从不同维度提供一些有用的分析思路给用户。

请一步一步思考,确保只以JSON格式回答,具体格式如下:
{response}
"""

_RESPONSE_FORMAT_SIMPLE_ZH = {
    "DataAnalysis": "数据内容分析总结",
    "ColumnAnalysis": [{"column name": "字段1介绍,专业术语解释(请尽量简单明了)"}],
    "AnalysisProgram": ["1.分析方案1", "2.分析方案2"],
}
_RESPONSE_FORMAT_SIMPLE_EN = {
    "DataAnalysis": "Data content analysis summary",
    "ColumnAnalysis": [
        {
            "column name": "Introduction to column 1 and explanation of technical terms (please be as simple and clear as possible)"
        }
    ],
    "AnalysisProgram": ["1. Analysis plan", "2. Analysis plan"],
}

RESPONSE_FORMAT_SIMPLE = (
    _RESPONSE_FORMAT_SIMPLE_EN if CFG.LANGUAGE == "en" else _RESPONSE_FORMAT_SIMPLE_ZH
)


_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_SCENE_DEFINE = (
    _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)


PROMPT_NEED_STREAM_OUT = False

# Temperature controls the randomness of the model output; see the note in
# excel_analyze/prompt.py. The higher value here allows more creative column analysis.
PROMPT_TEMPERATURE = 0.8

prompt = PromptTemplate(
    template_scene=ChatScene.ExcelLearning.value(),
    input_variables=["data_example"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=LearningExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)
@@ -0,0 +1,23 @@
import sqlparse


def add_quotes(sql, column_names=[]):
    parsed = sqlparse.parse(sql)
    for stmt in parsed:
        for token in stmt.tokens:
            deep_quotes(token, column_names)
    return str(parsed[0])


def deep_quotes(token, column_names=[]):
    if hasattr(token, "tokens"):
        for token_child in token.tokens:
            deep_quotes(token_child, column_names)
    else:
        if token.ttype == sqlparse.tokens.Name:
            if len(column_names) > 0:
                if token.value in column_names:
                    token.value = f'"{token.value}"'
            else:
                token.value = f'"{token.value}"'
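A quick self-contained check of these helpers, as a sketch assuming sqlparse is installed: add_quotes relies on the fact that mutating a leaf token's value propagates into str(parsed[0]). The expected output shown in the comment is approximate and depends on sqlparse's tokenization.

sql = "SELECT city, amount FROM excel_data WHERE amount > 30"
print(add_quotes(sql, column_names=["city", "amount"]))
# Expected: SELECT "city", "amount" FROM excel_data WHERE "amount" > 30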
302 dbgpt/app/scene/chat_data/chat_excel/excel_reader.py Normal file
@@ -0,0 +1,302 @@
import logging

import duckdb
import os
import sqlparse
import pandas as pd
import chardet
import numpy as np
from pyparsing import (
    CaselessKeyword,
    Word,
    alphanums,
    delimitedList,
    Forward,
    Optional,
    Literal,
    Regex,
)

from dbgpt.util.pd_utils import csv_colunm_foramt
from dbgpt.util.string_utils import is_chinese_include_number

logger = logging.getLogger(__name__)


def excel_colunm_format(old_name: str) -> str:
    new_column = old_name.strip()
    new_column = new_column.replace(" ", "_")
    return new_column


def detect_encoding(file_path):
    # Read the file's raw bytes.
    with open(file_path, "rb") as f:
        data = f.read()
    # Use chardet to detect the file encoding.
    result = chardet.detect(data)
    encoding = result["encoding"]
    confidence = result["confidence"]
    return encoding, confidence


def add_quotes_ex(sql: str, column_names):
    sql = sql.replace("`", '"')
    for column_name in column_names:
        if sql.find(column_name) != -1 and sql.find(f'"{column_name}"') == -1:
            sql = sql.replace(column_name, f'"{column_name}"')
    return sql


def parse_sql(sql):
    # Define keywords and identifiers.
    select_stmt = Forward()
    column = Regex(r"[\w一-龥]*")
    table = Word(alphanums)
    join_expr = Forward()
    where_expr = Forward()
    group_by_expr = Forward()
    order_by_expr = Forward()

    select_keyword = CaselessKeyword("SELECT")
    from_keyword = CaselessKeyword("FROM")
    join_keyword = CaselessKeyword("JOIN")
    on_keyword = CaselessKeyword("ON")
    where_keyword = CaselessKeyword("WHERE")
    group_by_keyword = CaselessKeyword("GROUP BY")
    order_by_keyword = CaselessKeyword("ORDER BY")
    and_keyword = CaselessKeyword("AND")
    or_keyword = CaselessKeyword("OR")
    in_keyword = CaselessKeyword("IN")
    not_in_keyword = CaselessKeyword("NOT IN")

    # Define the grammar rules.
    select_stmt <<= (
        select_keyword
        + delimitedList(column)
        + from_keyword
        + delimitedList(table)
        + Optional(join_expr)
        + Optional(where_keyword + where_expr)
        + Optional(group_by_keyword + group_by_expr)
        + Optional(order_by_keyword + order_by_expr)
    )

    join_expr <<= join_keyword + table + on_keyword + column + Literal("=") + column

    where_expr <<= (
        column + Literal("=") + Word(alphanums) + Optional(and_keyword + where_expr)
        | column + Literal(">") + Word(alphanums) + Optional(and_keyword + where_expr)
        | column + Literal("<") + Word(alphanums) + Optional(and_keyword + where_expr)
    )

    group_by_expr <<= delimitedList(column)

    order_by_expr <<= column + Optional(Literal("ASC") | Literal("DESC"))

    # Parse the SQL statement.
    parsed_result = select_stmt.parseString(sql)

    return parsed_result.asList()


def add_quotes(sql, column_names=[]):
    sql = sql.replace("`", "")
    sql = sql.replace("'", "")
    parsed = sqlparse.parse(sql)
    for stmt in parsed:
        for token in stmt.tokens:
            deep_quotes(token, column_names)
    return str(parsed[0])


def deep_quotes(token, column_names=[]):
    if hasattr(token, "tokens"):
        for token_child in token.tokens:
            deep_quotes(token_child, column_names)
    else:
        if is_chinese_include_number(token.value):
            new_value = token.value.replace("`", "").replace("'", "")
            token.value = f'"{new_value}"'


def get_select_clause(sql):
    parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement.

    select_tokens = []
    is_select = False

    for token in parsed.tokens:
        if token.is_keyword and token.value.upper() == "SELECT":
            is_select = True
        elif is_select:
            if token.is_keyword and token.value.upper() == "FROM":
                break
            select_tokens.append(token)
    return "".join(str(token) for token in select_tokens)


def parse_select_fields(sql):
    parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement.
    fields = []

    for token in parsed.tokens:
        # flatten() merges adjacent tokens such as '2022' and '年' into one token.
        if token.match(sqlparse.tokens.Literal.String.Single):
            token.flatten()
        if isinstance(token, sqlparse.sql.Identifier):
            fields.append(token.get_real_name())

    # Wrap field names (e.g. Chinese column names) in double quotes.
    fields = [f'"{field}"' for field in fields]

    return fields


def add_quotes_to_chinese_columns(sql, column_names=[]):
    parsed = sqlparse.parse(sql)
    for stmt in parsed:
        process_statement(stmt, column_names)
    return str(parsed[0])


def process_statement(statement, column_names=[]):
    if isinstance(statement, sqlparse.sql.IdentifierList):
        for identifier in statement.get_identifiers():
            process_identifier(identifier, column_names)
    elif isinstance(statement, sqlparse.sql.Identifier):
        process_identifier(statement, column_names)
    elif isinstance(statement, sqlparse.sql.TokenList):
        for item in statement.tokens:
            process_statement(item, column_names)


def process_identifier(identifier, column_names=[]):
    # if identifier.has_alias():
    #     alias = identifier.get_alias()
    #     identifier.tokens[-1].value = "[" + alias + "]"
    if hasattr(identifier, "tokens") and identifier.value in column_names:
        if is_chinese(identifier.value):
            new_value = get_new_value(identifier.value)
            identifier.value = new_value
            identifier.normalized = new_value
            identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
    else:
        if hasattr(identifier, "tokens"):
            for token in identifier.tokens:
                if isinstance(token, sqlparse.sql.Function):
                    process_function(token)
                elif token.ttype in sqlparse.tokens.Name:
                    new_value = get_new_value(token.value)
                    token.value = new_value
                    token.normalized = new_value
                elif token.value in column_names:
                    new_value = get_new_value(token.value)
                    token.value = new_value
                    token.normalized = new_value
                    token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]


def get_new_value(value):
    return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """


def process_function(function):
    function_params = list(function.get_parameters())
    for i in range(len(function_params)):
        param = function_params[i]
        # If the parameter is an identifier (a column name), quote its value.
        if isinstance(param, sqlparse.sql.Identifier):
            new_value = get_new_value(param.value)
            function_params[i].tokens = [
                sqlparse.sql.Token(sqlparse.tokens.Name, new_value)
            ]
    print(str(function))


def is_chinese(text):
    for char in text:
        if "一" <= char <= "鿿":
            return True
    return False


class ExcelReader:
    def __init__(self, file_path):
        file_name = os.path.basename(file_path)
        self.file_name_without_extension = os.path.splitext(file_name)[0]
        encoding, confidence = detect_encoding(file_path)
        logger.info(f"Detected encoding: {encoding} (confidence: {confidence})")
        self.excel_file_name = file_name
        self.extension = os.path.splitext(file_name)[1]
        # Read the excel (or csv) file.
        if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
            df_tmp = pd.read_excel(file_path, index_col=False)
            self.df = pd.read_excel(
                file_path,
                index_col=False,
                converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
            )
        elif file_path.endswith(".csv"):
            df_tmp = pd.read_csv(file_path, index_col=False, encoding=encoding)
            self.df = pd.read_csv(
                file_path,
                index_col=False,
                encoding=encoding,
                converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
            )
        else:
            raise ValueError("Unsupported file format.")

        self.df.replace("", np.nan, inplace=True)
        self.columns_map = {}
        for column_name in df_tmp.columns:
            self.columns_map.update({column_name: excel_colunm_format(column_name)})
            try:
                if not pd.api.types.is_datetime64_ns_dtype(self.df[column_name]):
                    self.df[column_name] = pd.to_numeric(self.df[column_name])
                    self.df[column_name] = self.df[column_name].fillna(0)
            except Exception as e:
                print("can't convert column to numeric: " + column_name)

        self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_"))

        # Connect to DuckDB.
        self.db = duckdb.connect(database=":memory:", read_only=False)

        self.table_name = "excel_data"
        # Register the DataFrame as a DuckDB table.
        self.db.register(self.table_name, self.df)

        # Fetch and print the table schema.
        result = self.db.execute(f"DESCRIBE {self.table_name}")
        columns = result.fetchall()
        for column in columns:
            print(column)

    def run(self, sql):
        try:
            if f'"{self.table_name}"' in sql:
                sql = sql.replace(f'"{self.table_name}"', self.table_name)
            sql = add_quotes_to_chinese_columns(sql)
            print(f"execute sql: {sql}")
            results = self.db.execute(sql)
            colunms = []
            for descrip in results.description:
                colunms.append(descrip[0])
            return colunms, results.fetchall()
        except Exception as e:
            logger.error(f"excel sql run error! {str(e)}")
            raise ValueError(f"Data Query Exception!\nSQL[{sql}].\nError: {str(e)}")

    def get_df_by_sql_ex(self, sql):
        colunms, values = self.run(sql)
        return pd.DataFrame(values, columns=colunms)

    def get_sample_data(self):
        return self.run(f"SELECT * FROM {self.table_name} LIMIT 5;")
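To see the DuckDB-over-DataFrame pattern used by ExcelReader in isolation, a minimal self-contained sketch (the data and column names are made up):

import duckdb
import pandas as pd

df = pd.DataFrame({"城市": ["北京", "上海", "北京"], "销量": [10, 20, 15]})
db = duckdb.connect(database=":memory:", read_only=False)
db.register("excel_data", df)  # same registration step as ExcelReader.__init__

# Chinese column names must be double-quoted in DuckDB SQL, which is what
# add_quotes_to_chinese_columns automates for model-generated SQL.
result = db.execute('SELECT "城市", SUM("销量") AS total FROM excel_data GROUP BY "城市"')
print([d[0] for d in result.description])  # ['城市', 'total']
print(result.fetchall())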