diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/rdbms_connect.py index b57a51978..d2a853589 100644 --- a/pilot/connections/rdbms/rdbms_connect.py +++ b/pilot/connections/rdbms/rdbms_connect.py @@ -252,7 +252,7 @@ class RDBMSDatabase(BaseConnect): def __write(self, session, write_sql): print(f"Write[{write_sql}]") - db_cache = self.get_session_db(session) + db_cache = self._engine.url.database result = session.execute(text(write_sql)) session.commit() # TODO Subsequent optimization of dynamically specified database submission loss target problem @@ -318,7 +318,7 @@ class RDBMSDatabase(BaseConnect): print("SQL:" + command) if not command: return [] - parsed, ttype, sql_type = self.__sql_parse(command) + parsed, ttype, sql_type, table_name = self.__sql_parse(command) if ttype == sqlparse.tokens.DML: if sql_type == "SELECT": return self.__query(session, command, fetch) @@ -339,10 +339,10 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) print("DDL Result:" + str(result)) if not result: - return self.__query(session, "SHOW COLUMNS FROM test") + return self.__query(session, f"SHOW COLUMNS FROM {table_name}") return result else: - return self.__query(session, "SHOW COLUMNS FROM test") + return self.__query(session, f"SHOW COLUMNS FROM {table_name}") def run_no_throw(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results. @@ -424,11 +424,12 @@ class RDBMSDatabase(BaseConnect): sql = sql.strip() parsed = sqlparse.parse(sql)[0] sql_type = parsed.get_type() + table_name = parsed.get_name() first_token = parsed.token_first(skip_ws=True, skip_cm=False) ttype = first_token.ttype - print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}") - return parsed, ttype, sql_type + print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}") + return parsed, ttype, sql_type, table_name def get_indexes(self, table_name): """Get table indexes about specified table."""