mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-20 01:07:15 +00:00
chat with plugin bug fix
This commit is contained in:
parent
ced9b581fc
commit
6c2ab298a0
32
pilot/common/custom_data_structure.py
Normal file
32
pilot/common/custom_data_structure.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
class FixedSizeDict(OrderedDict):
    """Insertion-ordered dict that keeps at most ``max_size`` entries.

    Storing a *new* key while the dict is full evicts the oldest entry
    (the first one in insertion order). Overwriting an existing key only
    replaces its value: nothing is evicted and the key keeps its position.

    Args:
        max_size: Maximum number of entries to retain; expected to be a
            positive integer.
    """

    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        # Only a brand-new key makes the dict grow. Evicting on every
        # assignment would drop an unrelated entry whenever an existing
        # key is updated in a full dict.
        if key not in self and len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)
|
||||||
|
|
||||||
|
class FixedSizeList:
    """Bounded list-like wrapper around ``collections.deque``.

    The underlying deque is created with ``maxlen=max_size``, so appending
    to a full container discards the item at the opposite (left) end.

    Attributes are kept public for callers:
        max_size -- the capacity passed to the constructor
        list     -- the backing ``deque``
    """

    def __init__(self, max_size):
        self.list = deque(maxlen=max_size)
        self.max_size = max_size

    def append(self, value):
        """Add ``value`` at the right end of the backing deque."""
        self.list.append(value)

    def __getitem__(self, index):
        """Return the element stored at ``index``."""
        backing = self.list
        return backing[index]

    def __setitem__(self, index, value):
        """Replace the element stored at ``index`` with ``value``."""
        backing = self.list
        backing[index] = value

    def __len__(self):
        """Return the number of elements currently held."""
        return len(self.list)

    def __str__(self):
        """Render the contents the same way a plain ``list`` would."""
        return str([*self.list])
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import sqlparse
|
||||||
|
import regex as re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||||
@ -18,7 +19,6 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
|||||||
from sqlalchemy.schema import CreateTable
|
from sqlalchemy.schema import CreateTable
|
||||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||||
|
|
||||||
|
|
||||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||||
return (
|
return (
|
||||||
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
||||||
@ -143,6 +143,12 @@ class Database:
|
|||||||
)
|
)
|
||||||
return self.get_usable_table_names()
|
return self.get_usable_table_names()
|
||||||
|
|
||||||
|
def get_session_db(self, connect):
    """Return the database currently selected on ``connect``.

    Executes ``select DATABASE()`` through ``connect.execute`` and returns
    the first column of the first row of the result.

    Args:
        connect: An object exposing ``execute`` (a session or connection).
    """
    current_db_row = connect.execute(text("select DATABASE()")).fetchone()
    return current_db_row[0]
|
||||||
|
|
||||||
def get_session(self, db_name: str):
|
def get_session(self, db_name: str):
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
|
|
||||||
@ -275,10 +281,31 @@ class Database:
|
|||||||
"""Format the error message"""
|
"""Format the error message"""
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
def __write(self, session, write_sql):
|
||||||
"""Execute a SQL command and return a string representing the results."""
|
print(f"Write[{write_sql}]")
|
||||||
print("sql run:" + command)
|
db_cache = self.get_session_db(session)
|
||||||
cursor = session.execute(text(command))
|
result = session.execute(text(write_sql))
|
||||||
|
session.commit()
|
||||||
|
#TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||||
|
session.execute(text(f"use `{db_cache}`"))
|
||||||
|
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
def __query(self,session, query, fetch: str = "all"):
|
||||||
|
"""
|
||||||
|
only for query
|
||||||
|
Args:
|
||||||
|
session:
|
||||||
|
query:
|
||||||
|
fetch:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
print(f"Query[{query}]")
|
||||||
|
if not query:
|
||||||
|
return []
|
||||||
|
cursor = session.execute(text(query))
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
if fetch == "all":
|
if fetch == "all":
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
@ -292,6 +319,36 @@ class Database:
|
|||||||
result.insert(0, field_names)
|
result.insert(0, field_names)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def run(self, session, command: str, fetch: str = "all") -> List:
    """Execute a SQL command and return its result rows.

    The command is classified with ``sqlparse``:

    * ``SELECT`` statements are run as plain queries.
    * Other DML statements are executed as writes; a follow-up SELECT
      derived by ``convert_sql_write_to_select`` is then run so the
      caller gets rows to display.
    * Anything else is executed directly and committed.

    Args:
        session: Database session used to execute the statement.
        command: The SQL text to run. An empty or ``None`` command
            yields an empty result.
        fetch: Fetch mode forwarded to the query helper for SELECTs.

    Returns:
        A list whose first element is the tuple of column names followed
        by the result rows, or an empty list when there is nothing to
        return.
    """
    # Format rather than concatenate: with ``"SQL:" + command`` a ``None``
    # command raised TypeError before the emptiness guard below could run.
    print(f"SQL:{command}")
    if not command:
        return []

    _, ttype, sql_type = self.__sql_parse(command)
    if ttype == sqlparse.tokens.DML:
        if sql_type == "SELECT":
            return self.__query(session, command, fetch)
        # Writes return no rows themselves; re-read the affected data.
        self.__write(session, command)
        select_sql = self.convert_sql_write_to_select(command)
        print(f"write result query:{select_sql}")
        return self.__query(session, select_sql)

    print("DDL execution determines whether to enable through configuration ")
    cursor = session.execute(text(command))
    session.commit()
    if not cursor.returns_rows:
        return []
    result = cursor.fetchall()
    field_names = tuple(i[0:] for i in cursor.keys())
    result = list(result)
    # The header row goes first so callers can build a table from it.
    result.insert(0, field_names)
    print("DDL Result:" + str(result))
    return result
|
||||||
|
|
||||||
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
||||||
"""Execute a SQL command and return a string representing the results.
|
"""Execute a SQL command and return a string representing the results.
|
||||||
|
|
||||||
@ -315,3 +372,60 @@ class Database:
|
|||||||
for d in results
|
for d in results
|
||||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def convert_sql_write_to_select(self, write_sql):
    """
    Build a SELECT statement that reads back the data touched by a write.

    SQL classification processing
    author:xiangh8

    Keywords are matched case-insensitively, but identifiers and literal
    values are copied verbatim (case, spacing and non-ASCII text intact).

    Args:
        write_sql: A single INSERT, DELETE or UPDATE statement.

    Returns:
        * INSERT ... (cols) VALUES (vals): ``SELECT * FROM <table> WHERE``
          followed by one ``col=val`` condition per column (``col IS NULL``
          for NULL values), joined with ``AND``. Only the first VALUES
          tuple is used.
        * DELETE FROM <table> ...: ``SELECT * FROM <table>`` (the deleted
          rows can no longer be selected, so the whole table is returned).
        * UPDATE <table> SET ... [WHERE ...]: a SELECT of the assigned
          columns, with the original WHERE clause when there is one.
        * ``None`` when the statement has a supported type but a shape
          this parser does not recognise (e.g. ``INSERT ... SELECT``).

    Raises:
        ValueError: If the statement is not an INSERT, DELETE or UPDATE.
    """

    def split_top_level(fragment):
        """Split ``fragment`` on commas outside quotes and parentheses.

        Scanning stops at the first unbalanced ``)``. Returns the stripped
        items and the index where scanning stopped (``len(fragment)`` when
        no closing parenthesis was found).
        """
        items, current = [], []
        quote = None
        depth = 0
        pos = 0
        while pos < len(fragment):
            ch = fragment[pos]
            if quote is not None:
                current.append(ch)
                if ch == "\\" and quote != "`" and pos + 1 < len(fragment):
                    # Keep a backslash-escaped character inside the literal.
                    pos += 1
                    current.append(fragment[pos])
                elif ch == quote:
                    quote = None
            elif ch in "'\"`":
                quote = ch
                current.append(ch)
            elif ch == "(":
                depth += 1
                current.append(ch)
            elif ch == ")":
                if depth == 0:
                    break
                depth -= 1
                current.append(ch)
            elif ch == "," and depth == 0:
                items.append("".join(current).strip())
                current = []
            else:
                current.append(ch)
            pos += 1
        items.append("".join(current).strip())
        return items, pos

    # A table reference: plain or backtick-quoted, optionally qualified.
    ident = r"(?:`[^`]+`|\w+)"
    table = rf"({ident}(?:\s*\.\s*{ident})*)"
    flags = re.IGNORECASE | re.DOTALL

    # Drop surrounding whitespace and the statement terminator only.
    sql = write_sql.strip().rstrip(";").rstrip()
    # Get the command type (insert, delete, update).
    words = sql.split(None, 1)
    cmd_type = words[0].lower() if words else ""

    # Dispatch on the command type.
    if cmd_type == "insert":
        head = re.match(rf"insert\s+into\s+{table}\s*\(", sql, flags)
        if head is None:
            return None
        table_name = head.group(1)
        after_open = sql[head.end():]
        # Split the column list and the value list into single items.
        columns, stop = split_top_level(after_open)
        tail = re.match(r"\)\s*values\s*\(", after_open[stop:], flags)
        if tail is None:
            return None
        values_text = after_open[stop + tail.end():]
        values, stop = split_top_level(values_text)
        if stop == len(values_text):
            # The VALUES tuple was never closed.
            return None
        # Build the WHERE clause; ``col=NULL`` never matches, so use IS NULL.
        conditions = [
            f"{col} IS NULL" if val.upper() == "NULL" else f"{col}={val}"
            for col, val in zip(columns, values)
            if col and val
        ]
        if not conditions:
            return f"SELECT * FROM {table_name}"
        where_clause = " AND ".join(conditions)
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    if cmd_type == "delete":
        # delete from <table_name> ...
        found = re.match(rf"delete\s+from\s+{table}", sql, flags)
        if found is None:
            return None
        # Return a select statement covering all data in that table.
        return f"SELECT * FROM {found.group(1)}"

    if cmd_type == "update":
        found = re.match(
            rf"update\s+{table}\s+set\s+(.*?)(?:\s+where\s+(.*))?$", sql, flags
        )
        if found is None:
            return None
        table_name, set_text, where_text = found.groups()
        # Extract the column names assigned in the ``set`` clause.
        assignments, _ = split_top_level(set_text)
        columns = [item.split("=")[0].strip() for item in assignments if item]
        if not columns:
            return None
        select_sql = f"SELECT {', '.join(columns)} FROM {table_name}"
        # Reuse the condition after ``where`` when one was given; an
        # UPDATE without WHERE touches (and therefore selects) every row.
        if where_text:
            return f"{select_sql} WHERE {where_text.strip()}"
        return select_sql

    raise ValueError(f"Unsupported SQL command type: {cmd_type}")
|
||||||
|
|
||||||
|
def __sql_parse(self, sql):
    """Parse ``sql`` with ``sqlparse`` and report how it is classified.

    Only the first statement returned by ``sqlparse.parse`` is inspected.

    Args:
        sql: SQL text; surrounding whitespace is stripped before parsing.

    Returns:
        A ``(parsed, ttype, sql_type)`` tuple: the first parsed statement,
        the token type of its first token (whitespace skipped, comments
        not skipped), and the value of the statement's ``get_type()``.
    """
    sql = sql.strip()
    statement = sqlparse.parse(sql)[0]
    statement_type = statement.get_type()

    leading_token = statement.token_first(skip_ws=True, skip_cm=False)
    token_type = leading_token.ttype
    print(f"SQL:{sql}, ttype:{token_type}, sql_type:{statement_type}")
    return statement, token_type, statement_type
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,6 +102,9 @@ class Config(metaclass=Singleton):
|
|||||||
self.plugins_denylist = plugins_denylist.split(",")
|
self.plugins_denylist = plugins_denylist.split(",")
|
||||||
else:
|
else:
|
||||||
self.plugins_denylist = []
|
self.plugins_denylist = []
|
||||||
|
### Native SQL Execution Capability Control Configuration
|
||||||
|
self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True"
|
||||||
|
self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True"
|
||||||
|
|
||||||
### Local database connection configuration
|
### Local database connection configuration
|
||||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
||||||
|
@ -121,10 +121,10 @@ class BaseOutputParser(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
cleaned_output = model_out_text.rstrip()
|
cleaned_output = model_out_text.rstrip()
|
||||||
if "```json" in cleaned_output:
|
# if "```json" in cleaned_output:
|
||||||
_, cleaned_output = cleaned_output.split("```json")
|
# _, cleaned_output = cleaned_output.split("```json")
|
||||||
if "```" in cleaned_output:
|
# if "```" in cleaned_output:
|
||||||
cleaned_output, _ = cleaned_output.split("```")
|
# cleaned_output, _ = cleaned_output.split("```")
|
||||||
if cleaned_output.startswith("```json"):
|
if cleaned_output.startswith("```json"):
|
||||||
cleaned_output = cleaned_output[len("```json"):]
|
cleaned_output = cleaned_output[len("```json"):]
|
||||||
if cleaned_output.startswith("```"):
|
if cleaned_output.startswith("```"):
|
||||||
|
@ -51,7 +51,33 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ss = "{\n \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}"
|
db = CFG.local_db
|
||||||
ss.strip().replace('\n', '').replace('\\n', '').replace('', '').replace(' ', '').replace('\\', '').replace('\\', '')
|
connect = db.get_session("gpt-user")
|
||||||
print(ss)
|
|
||||||
json.loads(ss)
|
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
|
||||||
|
FROM `gpt-user`.users
|
||||||
|
WHERE user_name='test1';
|
||||||
|
""")
|
||||||
|
|
||||||
|
print(str(db.get_session_db(connect)))
|
||||||
|
print(str(results))
|
||||||
|
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
|
||||||
|
FROM `gpt-user`.users
|
||||||
|
WHERE user_name='test2';
|
||||||
|
""")
|
||||||
|
print(str(db.get_session_db(connect)))
|
||||||
|
print(str(results))
|
||||||
|
|
||||||
|
results = db.run(connect, """INSERT INTO `gpt-user`.users
|
||||||
|
(user_name, phone, email, city, create_time, last_login_time)
|
||||||
|
VALUES('test4', '23', NULL, '成都', '2023-05-09 09:09:09', NULL);
|
||||||
|
""")
|
||||||
|
print(str(db.get_session_db(connect)))
|
||||||
|
print(str(results))
|
||||||
|
|
||||||
|
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
|
||||||
|
FROM `gpt-user`.users
|
||||||
|
WHERE user_name='test3';
|
||||||
|
""")
|
||||||
|
print(str(db.get_session_db(connect)))
|
||||||
|
print(str(results))
|
@ -30,6 +30,8 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
|
|
||||||
def parse_view_response(self, speak, data) -> str:
|
def parse_view_response(self, speak, data) -> str:
|
||||||
### tool out data to table view
|
### tool out data to table view
|
||||||
|
if len(data) <= 1:
|
||||||
|
data.insert(0, ["result"])
|
||||||
df = pd.DataFrame(data[1:], columns=data[0])
|
df = pd.DataFrame(data[1:], columns=data[0])
|
||||||
table_style = """<style>
|
table_style = """<style>
|
||||||
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
|
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
|
||||||
|
@ -15,6 +15,7 @@ _DEFAULT_TEMPLATE = """
|
|||||||
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||||
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
||||||
Use as few tables as possible when querying.
|
Use as few tables as possible when querying.
|
||||||
|
When generating insert, delete, update, or replace SQL, please make sure to use the data given by the human, and cannot use any unknown data. If you do not get enough information, speak to user: I don’t have enough data complete your request.
|
||||||
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -50,7 +50,4 @@ class ChatWithDbQA(BaseChat):
|
|||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_with_prompt_response(self, prompt_response):
|
def do_with_prompt_response(self, prompt_response):
|
||||||
if self.auto_execute:
|
return prompt_response
|
||||||
return self.database.run(self.db_connect, prompt_response.sql)
|
|
||||||
else:
|
|
||||||
return prompt_response
|
|
||||||
|
@ -333,9 +333,6 @@ def http_bot(
|
|||||||
state.messages[-1][-1] = "Error:" + str(e)
|
state.messages[-1][-1] = "Error:" + str(e)
|
||||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||||
|
|
||||||
if state.messages[-1][-1].endwith("▌"):
|
|
||||||
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
|
||||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
|
||||||
|
|
||||||
block_css = (
|
block_css = (
|
||||||
code_highlight_css
|
code_highlight_css
|
||||||
|
@ -15,6 +15,7 @@ frozenlist==1.3.3
|
|||||||
huggingface-hub==0.13.4
|
huggingface-hub==0.13.4
|
||||||
importlib-resources==5.12.0
|
importlib-resources==5.12.0
|
||||||
|
|
||||||
|
sqlparse==0.4.4
|
||||||
kiwisolver==1.4.4
|
kiwisolver==1.4.4
|
||||||
matplotlib==3.7.0
|
matplotlib==3.7.0
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
|
Loading…
Reference in New Issue
Block a user