# --- pilot/common/custom_data_structure.py ---
from collections import OrderedDict
from collections import deque


class FixedSizeDict(OrderedDict):
    """Insertion-ordered mapping that holds at most ``max_size`` entries.

    When a *new* key is stored while the mapping is full, the entry that was
    inserted first is discarded. Overwriting an existing key never evicts
    anything and keeps that key's original position.
    """

    def __init__(self, max_size):
        """Create an empty mapping.

        Args:
            max_size: Maximum number of entries to retain; must be >= 1.

        Raises:
            ValueError: If ``max_size`` is smaller than 1.
        """
        super().__init__()
        # A capacity below 1 can never hold an entry; fail early instead of
        # crashing in popitem() on the first assignment.
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.max_size = max_size

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, evicting the oldest entry if needed."""
        # Only a brand-new key can grow the mapping, so only then make room.
        if key not in self and len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)


class FixedSizeList:
    """Bounded sequence backed by ``collections.deque(maxlen=max_size)``.

    Appending to a full list silently drops the element at the front.
    """

    def __init__(self, max_size):
        """Create an empty list that retains at most ``max_size`` items."""
        self.max_size = max_size
        # The deque enforces the bound itself via ``maxlen``.
        self.list = deque(maxlen=max_size)

    def append(self, value):
        """Add ``value`` at the end."""
        self.list.append(value)

    def __getitem__(self, index):
        """Return the item at position ``index``."""
        return self.list[index]

    def __setitem__(self, index, value):
        """Replace the item at position ``index`` with ``value``."""
        self.list[index] = value

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.list)

    def __str__(self):
        """Render the contents like a plain ``list``."""
        return str(list(self.list))


# --- pilot/common/sql_database.py (method of class Database) ---
def get_session_db(self, connect):
    """Return the first column of ``select DATABASE()`` run on ``connect``.

    Args:
        connect: Object exposing ``execute(...)``; callers in this change
            pass the session they are working with.

    Returns:
        The value reported by the server for ``DATABASE()``.
        NOTE(review): presumably ``None`` when no default database is
        selected -- confirm against the target server.
    """
    sql = text("select DATABASE()")
    cursor = connect.execute(sql)
    return cursor.fetchone()[0]
# --- pilot/common/sql_database.py (methods of class Database) ---
def __write(self, session, write_sql):
    """Execute a write statement, commit it, and return the affected row count.

    Args:
        session: Session used to execute and commit the statement.
        write_sql: The INSERT/UPDATE/DELETE statement to run.

    Returns:
        ``rowcount`` of the executed statement.
    """
    print(f"Write[{write_sql}]")
    # Remember the selected database before committing so it can be restored.
    db_cache = self.get_session_db(session)
    result = session.execute(text(write_sql))
    session.commit()
    # TODO Subsequent optimization of dynamically specified database submission loss target problem
    # NOTE(review): presumably the commit can drop the session's selected
    # database, hence the explicit re-select -- confirm.
    if db_cache:
        # Without a selected database there is nothing to restore.
        session.execute(text(f"use `{db_cache}`"))
    print(f"SQL[{write_sql}], result:{result.rowcount}")
    return result.rowcount


def run(self, session, command: str, fetch: str = "all") -> List:
    """Execute a SQL command and return its rows, header row first.

    SELECT statements are queried directly. Other DML statements are
    executed and committed, then a follow-up SELECT derived from the write
    is run so the caller can see the affected data. Anything else is
    executed and committed as-is.

    Args:
        session: Session the command is executed on.
        command: SQL text; ``None``, empty or blank input yields ``[]``.
        fetch: Fetch mode forwarded to the query helper for SELECTs.

    Returns:
        A list whose first element is the tuple of column names followed by
        the rows, or ``[]`` when the statement returns no rows.
    """
    # Check before touching the string: None cannot be concatenated and a
    # blank statement gives the parser nothing to return.
    if not command or not command.strip():
        return []
    print("SQL:" + command)
    _, ttype, sql_type = self.__sql_parse(command)
    if ttype == sqlparse.tokens.DML:
        if sql_type == "SELECT":
            return self.__query(session, command, fetch)
        self.__write(session, command)
        try:
            select_sql = self.convert_sql_write_to_select(command)
        except ValueError as err:
            # The write is already committed; failing to build the echo
            # query must not be reported as a failed write.
            print(f"write result query unavailable:{err}")
            return []
        print(f"write result query:{select_sql}")
        return self.__query(session, select_sql)

    print("DDL execution determines whether to enable through configuration ")
    cursor = session.execute(text(command))
    session.commit()
    if not cursor.returns_rows:
        return []
    result = list(cursor.fetchall())
    # Column names go first so callers can use them as the table header.
    result.insert(0, tuple(cursor.keys()))
    print("DDL Result:" + str(result))
    return result


def convert_sql_write_to_select(self, write_sql):
    """Build a SELECT statement that shows the data touched by a write.

    The statement is parsed case-insensitively but its text is kept as
    written, so string literals and identifiers keep their case. Line
    breaks, a trailing ``;`` and backticked or schema-qualified table
    names are accepted.

    * ``INSERT INTO t (a, b) VALUES (1, NULL)`` ->
      ``SELECT * FROM t WHERE a=1 AND b IS NULL`` (first value tuple only).
    * ``DELETE FROM t ...`` -> ``SELECT * FROM t``.
    * ``UPDATE t SET a=1, b=2 [WHERE c]`` -> ``SELECT a, b FROM t [WHERE c]``.

    Args:
        write_sql: An INSERT, DELETE or UPDATE statement.

    Returns:
        The SELECT statement, or ``None`` for an INSERT that has no explicit
        ``(columns) VALUES (...)`` form.

    Raises:
        ValueError: For any other statement type, or a DELETE/UPDATE that
            cannot be parsed.
    """
    table_pattern = r"((?:`[^`]+`|\w+)(?:\.(?:`[^`]+`|\w+))*)"
    flags = re.IGNORECASE | re.DOTALL

    def split_items(sql_list):
        """Split on top-level commas, stopping at an unmatched ')'.

        Commas inside quotes or nested parentheses do not split. Backslash
        escaped quotes are not recognised; doubled quotes are.
        """
        items, current, depth, quote = [], [], 0, None
        for ch in sql_list:
            if quote is not None:
                current.append(ch)
                if ch == quote:
                    quote = None
                continue
            if ch in ("'", '"'):
                quote = ch
            elif ch == "(":
                depth += 1
            elif ch == ")":
                if depth == 0:
                    break
                depth -= 1
            elif ch == "," and depth == 0:
                items.append("".join(current).strip())
                current = []
                continue
            current.append(ch)
        items.append("".join(current).strip())
        return items

    statement = (write_sql or "").strip().rstrip(";").rstrip()
    # The command type (insert, delete, update) is the first word.
    cmd_type = statement.split(None, 1)[0].lower() if statement else ""

    if cmd_type == "insert":
        match = re.match(
            r"insert\s+into\s+" + table_pattern + r"\s*\(([^)]*)\)\s*values\s*\(",
            statement,
            flags,
        )
        if not match:
            return None
        table_name, column_text = match.groups()
        columns = split_items(column_text)
        # Everything after "VALUES (" up to its matching ")" is the first row.
        values = split_items(statement[match.end():])
        # Build the WHERE clause; "col=NULL" never matches, so use IS NULL.
        conditions = [
            f"{col} IS NULL" if val.upper() == "NULL" else f"{col}={val}"
            for col, val in zip(columns, values)
        ]
        where_clause = " AND ".join(conditions)
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    if cmd_type == "delete":
        match = re.match(r"delete\s+from\s+" + table_pattern, statement, flags)
        if not match:
            raise ValueError(f"Cannot parse DELETE statement: {write_sql}")
        # Show everything that is left in the table.
        return f"SELECT * FROM {match.group(1)}"

    if cmd_type == "update":
        # NOTE: a literal " where " inside a SET value would be mistaken for
        # the start of the WHERE clause.
        match = re.match(
            r"update\s+(.+?)\s+set\s+(.*?)(?:\s+where\s+(.*))?$",
            statement,
            flags,
        )
        if not match:
            raise ValueError(f"Cannot parse UPDATE statement: {write_sql}")
        table_ref, assignments, condition = match.groups()
        # Column names are whatever precedes the first "=" of each assignment.
        columns = [
            item.split("=", 1)[0].strip() for item in split_items(assignments)
        ]
        select_sql = f"SELECT {', '.join(columns)} FROM {table_ref.strip()}"
        if condition:
            select_sql += f" WHERE {condition.strip()}"
        return select_sql

    raise ValueError(f"Unsupported SQL command type: {cmd_type}")


def __sql_parse(self, sql):
    """Parse ``sql`` with sqlparse and classify its first statement.

    Args:
        sql: Non-blank SQL text.

    Returns:
        ``(parsed, ttype, sql_type)``: the first parsed statement, the token
        type of its first non-whitespace token, and ``parsed.get_type()``.
    """
    sql = sql.strip()
    parsed = sqlparse.parse(sql)[0]
    sql_type = parsed.get_type()

    # NOTE(review): skip_cm=False looks like it lets a leading comment be
    # returned as the first token -- confirm against sqlparse.
    first_token = parsed.token_first(skip_ws=True, skip_cm=False)
    ttype = first_token.ttype
    print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
    return parsed, ttype, sql_type
# --- pilot/scene/chat_db/auto_execute/chat.py ---
if __name__ == "__main__":
    # Manual smoke test: run a few statements on one session and, after
    # each, print the session's current database and the returned rows.
    db = CFG.local_db
    connect = db.get_session("gpt-user")

    select_columns = "user_name, phone, email, city, create_time, last_login_time"
    statements = (
        f"""SELECT {select_columns}
            FROM `gpt-user`.users
            WHERE user_name='test1';
        """,
        f"""SELECT {select_columns}
            FROM `gpt-user`.users
            WHERE user_name='test2';
        """,
        """INSERT INTO `gpt-user`.users
            (user_name, phone, email, city, create_time, last_login_time)
            VALUES('test4', '23', NULL, '成都', '2023-05-09 09:09:09', NULL);
        """,
        f"""SELECT {select_columns}
            FROM `gpt-user`.users
            WHERE user_name='test3';
        """,
    )

    try:
        for statement in statements:
            results = db.run(connect, statement)
            print(str(db.get_session_db(connect)))
            print(str(results))
    finally:
        # Release the session even if one of the statements fails.
        connect.close()