chat with plugin bug fix

This commit is contained in:
yhjun1026 2023-06-01 14:02:56 +08:00
parent ced9b581fc
commit 6c2ab298a0
10 changed files with 194 additions and 21 deletions

View File

@ -0,0 +1,32 @@
from collections import OrderedDict
from collections import deque
class FixedSizeDict(OrderedDict):
    """An ``OrderedDict`` that keeps at most ``max_size`` entries.

    Storing a *new* key while the dict is full drops the oldest entry
    (the one at the front of the insertion order) to make room.
    Overwriting the value of a key that is already present never evicts
    anything, because the number of entries does not change.

    Args:
        max_size: Maximum number of entries to retain.
    """

    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key, value):
        # Only a genuinely new key grows the dict; evicting on an overwrite
        # would throw away an unrelated entry for no reason.
        if key not in self and len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)
class FixedSizeList:
    """A bounded sequence backed by ``collections.deque``.

    The underlying deque is created with ``maxlen=max_size``, so appending
    to a full list discards the item at the opposite (left) end.

    Args:
        max_size: Maximum number of items kept.
    """

    def __init__(self, max_size):
        self.list = deque(maxlen=max_size)
        self.max_size = max_size

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.list)

    def __str__(self):
        """Render the contents the way a plain ``list`` would be printed."""
        snapshot = list(self.list)
        return str(snapshot)

    def __getitem__(self, index):
        """Return the item stored at ``index``."""
        return self.list[index]

    def __setitem__(self, index, value):
        """Replace the item stored at ``index`` with ``value``."""
        self.list[index] = value

    def append(self, value):
        """Add ``value`` at the right end of the underlying deque."""
        self.list.append(value)

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import sqlparse
import regex as re
import warnings import warnings
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra from pydantic import BaseModel, Field, root_validator, validator, Extra
@ -18,7 +19,6 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return ( return (
f'Name: {index["name"]}, Unique: {index["unique"]},' f'Name: {index["name"]}, Unique: {index["unique"]},'
@ -143,6 +143,12 @@ class Database:
) )
return self.get_usable_table_names() return self.get_usable_table_names()
def get_session_db(self, connect):
    """Return the database currently selected on ``connect``.

    Runs ``select DATABASE()`` through ``connect.execute`` and returns the
    first column of the first row of the result.

    Args:
        connect: Object exposing ``execute``; the caller in this file
            passes the session it obtained from ``get_session``.
    """
    first_row = connect.execute(text("select DATABASE()")).fetchone()
    return first_row[0]
def get_session(self, db_name: str): def get_session(self, db_name: str):
session = self._db_sessions() session = self._db_sessions()
@ -275,10 +281,31 @@ class Database:
"""Format the error message""" """Format the error message"""
return f"Error: {e}" return f"Error: {e}"
def run(self, session, command: str, fetch: str = "all") -> List: def __write(self, session, write_sql):
"""Execute a SQL command and return a string representing the results.""" print(f"Write[{write_sql}]")
print("sql run:" + command) db_cache = self.get_session_db(session)
cursor = session.execute(text(command)) result = session.execute(text(write_sql))
session.commit()
# TODO: follow-up fix needed - after a commit the session can lose its dynamically selected target database, so it is re-selected below as a workaround
session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self,session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows: if cursor.returns_rows:
if fetch == "all": if fetch == "all":
result = cursor.fetchall() result = cursor.fetchall()
@ -292,6 +319,36 @@ class Database:
result.insert(0, field_names) result.insert(0, field_names)
return result return result
def run(self, session, command: str, fetch: str = "all") -> List:
    """Execute a SQL command and return its result rows.

    The statement is classified with ``__sql_parse``:

    * ``SELECT`` is delegated to ``__query``.
    * Any other DML statement is executed through ``__write``; afterwards
      a SELECT derived by ``convert_sql_write_to_select`` is run so the
      caller gets rows to display.
    * Everything else is executed directly on ``session`` and committed.

    Args:
        session: Session the statement is executed on.
        command: SQL text; ``None``, empty or whitespace-only yields ``[]``.
        fetch: Fetch mode forwarded to ``__query`` for SELECT statements.

    Returns:
        A list of rows, or ``[]`` when there is nothing to return. On the
        non-DML path the row list is preceded by a tuple of column names.
    """
    # Validate before touching ``command``: logging first would raise a
    # TypeError for None, and blank text cannot be parsed into a statement.
    if not command or not command.strip():
        return []
    print("SQL:" + command)

    _, ttype, sql_type = self.__sql_parse(command)

    if ttype == sqlparse.tokens.DML:
        if sql_type == "SELECT":
            return self.__query(session, command, fetch)

        self.__write(session, command)
        try:
            select_sql = self.convert_sql_write_to_select(command)
        except (ValueError, IndexError) as e:
            # The write is already committed at this point; failing to
            # build the follow-up SELECT must not look like a failed write.
            print(f"write result query unavailable:{e}")
            return []
        print(f"write result query:{select_sql}")
        return self.__query(session, select_sql)

    print("DDL execution determines whether to enable through configuration ")
    cursor = session.execute(text(command))
    session.commit()
    if not cursor.returns_rows:
        return []

    result = list(cursor.fetchall())
    # Expression kept exactly as in the original; NOTE(review): the
    # ``[0:]`` slice looks like a plain copy of each key - confirm.
    field_names = tuple(i[0:] for i in cursor.keys())
    result.insert(0, field_names)
    print("DDL Result:" + str(result))
    return result
def run_no_throw(self, session, command: str, fetch: str = "all") -> List: def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results. """Execute a SQL command and return a string representing the results.
@ -315,3 +372,60 @@ class Database:
for d in results for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"] if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
] ]
def convert_sql_write_to_select(self, write_sql):
    """Derive a SELECT that shows the rows touched by a write statement.

    Original author: xiangh8

    Supported statements:

    * ``INSERT INTO <table> (<cols>) VALUES (<vals>)`` ->
      ``SELECT * FROM <table> WHERE col=val AND ...`` built from the first
      value tuple; a ``NULL`` value becomes ``col IS NULL``.
    * ``DELETE FROM <table> ...`` -> ``SELECT * FROM <table>``.
    * ``UPDATE <table> SET <col>=... [WHERE <cond>]`` ->
      ``SELECT <col> FROM <table> [WHERE <cond>]`` using the first
      assigned column.

    Keywords are matched case-insensitively, but identifiers and literals
    are copied from the input unchanged, so string values and
    case-sensitive table names keep their case.

    Args:
        write_sql: The INSERT / DELETE / UPDATE statement text.

    Returns:
        The SELECT statement, or ``None`` for an INSERT whose shape is not
        recognised (for example one without a column list).

    Raises:
        ValueError: If the command type is unsupported or the statement
            is too short to contain a table name / assignment.
    """

    def split_values(source, start):
        """Split a comma-separated value list beginning at ``start``.

        Scanning stops at the parenthesis that closes the list. Commas
        inside quotes or nested parentheses do not split.
        """
        items, current, depth, quote = [], [], 0, None
        pos = start
        while pos < len(source):
            ch = source[pos]
            if quote:
                current.append(ch)
                if ch == "\\" and pos + 1 < len(source):
                    # Keep a backslash-escaped character inside the literal.
                    pos += 1
                    current.append(source[pos])
                elif ch == quote:
                    quote = None
            elif ch in "'\"`":
                quote = ch
                current.append(ch)
            elif ch == "(":
                depth += 1
                current.append(ch)
            elif ch == ")":
                if depth == 0:
                    break
                depth -= 1
                current.append(ch)
            elif ch == "," and depth == 0:
                items.append("".join(current).strip())
                current = []
            else:
                current.append(ch)
            pos += 1
        items.append("".join(current).strip())
        return items

    # A trailing statement terminator would otherwise leak into the
    # generated table name or WHERE clause.
    sql = write_sql.strip().rstrip(";").strip()
    parts = sql.split()
    # Command type: insert, delete or update.
    cmd_type = parts[0].lower() if parts else ""

    if cmd_type == "insert":
        match = re.match(
            r"insert\s+into\s+([\w`.\-]+)\s*\((.*?)\)\s*values\s*\(",
            sql,
            re.IGNORECASE | re.DOTALL,
        )
        if not match:
            return None
        table_name, columns = match.groups()
        # Break the column list and the value list into single items.
        columns = [col.strip() for col in columns.split(",")]
        values = split_values(sql, match.end())
        # Build the WHERE clause; ``col=NULL`` never matches in SQL.
        conditions = [
            f"{col} IS NULL" if val.upper() == "NULL" else f"{col}={val}"
            for col, val in zip(columns, values)
        ]
        where_clause = " AND ".join(conditions)
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    if cmd_type == "delete":
        if len(parts) < 3:
            raise ValueError(f"Malformed DELETE statement: {write_sql}")
        table_name = parts[2]  # delete from <table_name> ...
        # Select everything that is left in the table.
        return f"SELECT * FROM {table_name}"

    if cmd_type == "update":
        keywords = [part.lower() for part in parts]
        if len(parts) < 2 or "set" not in keywords:
            raise ValueError(f"Malformed UPDATE statement: {write_sql}")
        table_name = parts[1]
        set_idx = keywords.index("set")
        # An UPDATE without WHERE is valid SQL and touches every row.
        where_idx = keywords.index("where") if "where" in keywords else len(parts)
        assignments = parts[set_idx + 1:where_idx]
        if not assignments:
            raise ValueError(f"Malformed UPDATE statement: {write_sql}")
        # Column name of the first assignment in the SET clause.
        set_clause = assignments[0].split("=")[0].strip()
        # Condition text that follows WHERE, if any.
        where_clause = " ".join(parts[where_idx + 1:])
        if where_clause:
            return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
        return f"SELECT {set_clause} FROM {table_name}"

    raise ValueError(f"Unsupported SQL command type: {cmd_type}")
def __sql_parse(self, sql):
    """Classify a SQL string with ``sqlparse``.

    Only the first statement found in ``sql`` is inspected.

    Args:
        sql: SQL text; surrounding whitespace is stripped before parsing.

    Returns:
        A ``(parsed, ttype, sql_type)`` tuple: the first parsed statement,
        the token type of its first non-whitespace token, and the value of
        ``parsed.get_type()``. When nothing can be parsed the tuple is
        ``(None, None, "UNKNOWN")``.
    """
    sql = sql.strip()
    statements = sqlparse.parse(sql)
    if not statements:
        # Blank input produces no statement; indexing [0] would raise
        # IndexError instead of letting the caller decide what to do.
        print(f"SQL:{sql}, ttype:{None}, sql_type:UNKNOWN")
        return None, None, "UNKNOWN"

    parsed = statements[0]
    sql_type = parsed.get_type()
    first_token = parsed.token_first(skip_ws=True, skip_cm=False)
    # token_first may find no token at all; treat that as "no type".
    ttype = first_token.ttype if first_token is not None else None
    print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
    return parsed, ttype, sql_type

View File

@ -102,6 +102,9 @@ class Config(metaclass=Singleton):
self.plugins_denylist = plugins_denylist.split(",") self.plugins_denylist = plugins_denylist.split(",")
else: else:
self.plugins_denylist = [] self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") =="True"
self.NATIVE_SQL_CAN_RUN_WRITE = os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") =="True"
### Local database connection configuration ### Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1") self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")

View File

@ -121,10 +121,10 @@ class BaseOutputParser(ABC):
""" """
cleaned_output = model_out_text.rstrip() cleaned_output = model_out_text.rstrip()
if "```json" in cleaned_output: # if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json") # _, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output: # if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```") # cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"): if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):] cleaned_output = cleaned_output[len("```json"):]
if cleaned_output.startswith("```"): if cleaned_output.startswith("```"):

View File

@ -51,7 +51,33 @@ class ChatWithDbAutoExecute(BaseChat):
if __name__ == "__main__": if __name__ == "__main__":
ss = "{\n \"thoughts\": \"to get the user's city, we need to join the users table with the tran_order table using the user_name column. we also need to filter the results to only show orders for user test1.\",\n \"sql\": \"select o.order_id, o.product_name, u.city from tran_order o join users u on o.user_name = u.user_name where o.user_name = 'test1' limit 5\"\n}" db = CFG.local_db
ss.strip().replace('\n', '').replace('\\n', '').replace('', '').replace(' ', '').replace('\\', '').replace('\\', '') connect = db.get_session("gpt-user")
print(ss)
json.loads(ss) results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test1';
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test2';
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """INSERT INTO `gpt-user`.users
(user_name, phone, email, city, create_time, last_login_time)
VALUES('test4', '23', NULL, '成都', '2023-05-09 09:09:09', NULL);
""")
print(str(db.get_session_db(connect)))
print(str(results))
results = db.run(connect, """SELECT user_name, phone, email, city, create_time, last_login_time
FROM `gpt-user`.users
WHERE user_name='test3';
""")
print(str(db.get_session_db(connect)))
print(str(results))

View File

@ -30,6 +30,8 @@ class DbChatOutputParser(BaseOutputParser):
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
### tool out data to table view ### tool out data to table view
if len(data) <= 1:
data.insert(0, ["result"])
df = pd.DataFrame(data[1:], columns=data[0]) df = pd.DataFrame(data[1:], columns=data[0])
table_style = """<style> table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444} table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}

View File

@ -15,6 +15,7 @@ _DEFAULT_TEMPLATE = """
You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
Use as few tables as possible when querying. Use as few tables as possible when querying.
When generating insert, delete, update, or replace SQL, please make sure to use the data given by the human, and cannot use any unknown data. If you do not get enough information, speak to user: I dont have enough data complete your request.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
""" """

View File

@ -50,7 +50,4 @@ class ChatWithDbQA(BaseChat):
return input_values return input_values
def do_with_prompt_response(self, prompt_response): def do_with_prompt_response(self, prompt_response):
if self.auto_execute: return prompt_response
return self.database.run(self.db_connect, prompt_response.sql)
else:
return prompt_response

View File

@ -333,9 +333,6 @@ def http_bot(
state.messages[-1][-1] = "Error:" + str(e) state.messages[-1][-1] = "Error:" + str(e)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
if state.messages[-1][-1].endwith(""):
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
block_css = ( block_css = (
code_highlight_css code_highlight_css

View File

@ -15,6 +15,7 @@ frozenlist==1.3.3
huggingface-hub==0.13.4 huggingface-hub==0.13.4
importlib-resources==5.12.0 importlib-resources==5.12.0
sqlparse==0.4.4
kiwisolver==1.4.4 kiwisolver==1.4.4
matplotlib==3.7.0 matplotlib==3.7.0
multidict==6.0.4 multidict==6.0.4