From dab39a11804db48915ff04d43c6dac7febe2079e Mon Sep 17 00:00:00 2001 From: yhjun1026 <460342015@qq.com> Date: Tue, 25 Jul 2023 19:00:03 +0800 Subject: [PATCH] DDL run bug fix --- pilot/connections/rdbms/rdbms_connect.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/rdbms_connect.py index a882040cd..4bfc0b891 100644 --- a/pilot/connections/rdbms/rdbms_connect.py +++ b/pilot/connections/rdbms/rdbms_connect.py @@ -253,7 +253,7 @@ class RDBMSDatabase(BaseConnect): def __write(self, session, write_sql): print(f"Write[{write_sql}]") - db_cache = self.get_session_db(session) + db_cache = self._engine.url.database result = session.execute(text(write_sql)) session.commit() # TODO Subsequent optimization of dynamically specified database submission loss target problem @@ -319,7 +319,7 @@ class RDBMSDatabase(BaseConnect): print("SQL:" + command) if not command: return [] - parsed, ttype, sql_type = self.__sql_parse(command) + parsed, ttype, sql_type, table_name = self.__sql_parse(command) if ttype == sqlparse.tokens.DML: if sql_type == "SELECT": return self.__query(session, command, fetch) @@ -340,10 +340,10 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) print("DDL Result:" + str(result)) if not result: - return self.__query(session, "SHOW COLUMNS FROM test") + return self.__query(session, f"SHOW COLUMNS FROM {table_name}") return result else: - return self.__query(session, "SHOW COLUMNS FROM test") + return self.__query(session, f"SHOW COLUMNS FROM {table_name}") def run_no_throw(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results. @@ -425,11 +425,12 @@ class RDBMSDatabase(BaseConnect): sql = sql.strip() parsed = sqlparse.parse(sql)[0] sql_type = parsed.get_type() + table_name = parsed.get_name() first_token = parsed.token_first(skip_ws=True, skip_cm=False) ttype = first_token.ttype - print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}") - return parsed, ttype, sql_type + print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}") + return parsed, ttype, sql_type, table_name def get_indexes(self, table_name): """Get table indexes about specified table."""