mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-29 21:49:35 +00:00
chat with plugin bug fix
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlparse
|
||||
import regex as re
|
||||
import warnings
|
||||
from typing import Any, Iterable, List, Optional
|
||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||
@@ -18,7 +19,6 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
|
||||
|
||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
return (
|
||||
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
||||
@@ -143,6 +143,12 @@ class Database:
|
||||
)
|
||||
return self.get_usable_table_names()
|
||||
|
||||
def get_session_db(self, connect):
|
||||
sql = text(f"select DATABASE()")
|
||||
cursor = connect.execute(sql)
|
||||
result = cursor.fetchone()[0]
|
||||
return result
|
||||
|
||||
def get_session(self, db_name: str):
|
||||
session = self._db_sessions()
|
||||
|
||||
@@ -275,10 +281,31 @@ class Database:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("sql run:" + command)
|
||||
cursor = session.execute(text(command))
|
||||
def __write(self, session, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
db_cache = self.get_session_db(session)
|
||||
result = session.execute(text(write_sql))
|
||||
session.commit()
|
||||
#TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||
session.execute(text(f"use `{db_cache}`"))
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def __query(self,session, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
Args:
|
||||
session:
|
||||
query:
|
||||
fetch:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
print(f"Query[{query}]")
|
||||
if not query:
|
||||
return []
|
||||
cursor = session.execute(text(query))
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
@@ -292,6 +319,36 @@ class Database:
|
||||
result.insert(0, field_names)
|
||||
return result
|
||||
|
||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("SQL:" + command)
|
||||
if not command:
|
||||
return []
|
||||
parsed, ttype, sql_type = self.__sql_parse(command)
|
||||
if ttype == sqlparse.tokens.DML:
|
||||
if sql_type == "SELECT":
|
||||
return self.__query(session, command, fetch)
|
||||
else:
|
||||
self.__write(session, command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
return self.__query(session, select_sql)
|
||||
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
cursor = session.execute(text(command))
|
||||
session.commit()
|
||||
if cursor.returns_rows:
|
||||
result = cursor.fetchall()
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
result = list(result)
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
|
||||
return result
|
||||
else:
|
||||
return []
|
||||
|
||||
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
@@ -315,3 +372,60 @@ class Database:
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
||||
|
||||
def convert_sql_write_to_select(self, write_sql):
|
||||
"""
|
||||
SQL classification processing
|
||||
author:xiangh8
|
||||
Args:
|
||||
sql:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# 将SQL命令转换为小写,并按空格拆分
|
||||
parts = write_sql.lower().split()
|
||||
# 获取命令类型(insert, delete, update)
|
||||
cmd_type = parts[0]
|
||||
|
||||
# 根据命令类型进行处理
|
||||
if cmd_type == 'insert':
|
||||
match = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower())
|
||||
if match:
|
||||
table_name, columns, values = match.groups()
|
||||
# 将字段列表和值列表分割为单独的字段和值
|
||||
columns = columns.split(',')
|
||||
values = values.split(',')
|
||||
# 构造 WHERE 子句
|
||||
where_clause = " AND ".join([f"{col.strip()}={val.strip()}" for col, val in zip(columns, values)])
|
||||
return f'SELECT * FROM {table_name} WHERE {where_clause}'
|
||||
|
||||
elif cmd_type == 'delete':
|
||||
table_name = parts[2] # delete from <table_name> ...
|
||||
# 返回一个select语句,它选择该表的所有数据
|
||||
return f'SELECT * FROM {table_name}'
|
||||
|
||||
elif cmd_type == 'update':
|
||||
table_name = parts[1]
|
||||
set_idx = parts.index('set')
|
||||
where_idx = parts.index('where')
|
||||
# 截取 `set` 子句中的字段名
|
||||
set_clause = parts[set_idx + 1: where_idx][0].split('=')[0].strip()
|
||||
# 截取 `where` 之后的条件语句
|
||||
where_clause = ' '.join(parts[where_idx + 1:])
|
||||
# 返回一个select语句,它选择更新的数据
|
||||
return f'SELECT {set_clause} FROM {table_name} WHERE {where_clause}'
|
||||
else:
|
||||
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
|
||||
|
||||
def __sql_parse(self, sql):
|
||||
sql = sql.strip()
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
sql_type = parsed.get_type()
|
||||
|
||||
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
|
||||
ttype = first_token.ttype
|
||||
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
|
||||
return parsed, ttype, sql_type
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user