chat with plugin bug fix

This commit is contained in:
yhjun1026
2023-06-01 14:02:56 +08:00
parent ced9b581fc
commit 6c2ab298a0
10 changed files with 194 additions and 21 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import sqlparse
import regex as re
import warnings
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
@@ -18,7 +19,6 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
@@ -143,6 +143,12 @@ class Database:
)
return self.get_usable_table_names()
def get_session_db(self, connect):
sql = text(f"select DATABASE()")
cursor = connect.execute(sql)
result = cursor.fetchone()[0]
return result
def get_session(self, db_name: str):
session = self._db_sessions()
@@ -275,10 +281,31 @@ class Database:
"""Format the error message"""
return f"Error: {e}"
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("sql run:" + command)
cursor = session.execute(text(command))
def __write(self, session, write_sql):
print(f"Write[{write_sql}]")
db_cache = self.get_session_db(session)
result = session.execute(text(write_sql))
session.commit()
#TODO Subsequent optimization of dynamically specified database submission loss target problem
session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self,session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
@@ -292,6 +319,36 @@ class Database:
result.insert(0, field_names)
return result
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)
if not command:
return []
parsed, ttype, sql_type = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query(session, command, fetch)
else:
self.__write(session, command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(session, select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
cursor = session.execute(text(command))
session.commit()
if cursor.returns_rows:
result = cursor.fetchall()
field_names = tuple(i[0:] for i in cursor.keys())
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
return result
else:
return []
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.
@@ -315,3 +372,60 @@ class Database:
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
def convert_sql_write_to_select(self, write_sql):
"""
SQL classification processing
author:xiangh8
Args:
sql:
Returns:
"""
# 将SQL命令转换为小写并按空格拆分
parts = write_sql.lower().split()
# 获取命令类型insert, delete, update
cmd_type = parts[0]
# 根据命令类型进行处理
if cmd_type == 'insert':
match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower())
if match:
table_name, columns, values = match.groups()
# 将字段列表和值列表分割为单独的字段和值
columns = columns.split(',')
values = values.split(',')
# 构造 WHERE 子句
where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)])
return f'SELECT * FROM {table_name} WHERE {where_clause}'
elif cmd_type == 'delete':
table_name = parts[2] # delete from <table_name> ...
# 返回一个select语句它选择该表的所有数据
return f'SELECT * FROM {table_name}'
elif cmd_type == 'update':
table_name = parts[1]
set_idx = parts.index('set')
where_idx = parts.index('where')
# 截取 `set` 子句中的字段名
set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip()
# 截取 `where` 之后的条件语句
where_clause = ' '.join(parts[where_idx + 1:])
# 返回一个select语句它选择更新的数据
return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}'
else:
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
def __sql_parse(self, sql):
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
return parsed, ttype, sql_type