diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index 7140fd1a3..37fc446cb 100644 --- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -54,7 +54,7 @@ async def get_editor_tables( for table in tables: table_node: DataNode = DataNode(title=table, key=table, type="table") db_node.children.append(table_node) - fields = db_conn.get_fields(table) + fields = db_conn.get_fields(table, db_name) for field in fields: table_node.children.append( DataNode( diff --git a/dbgpt/datasource/rdbms/base.py b/dbgpt/datasource/rdbms/base.py index b9072d7cb..f09a7553a 100644 --- a/dbgpt/datasource/rdbms/base.py +++ b/dbgpt/datasource/rdbms/base.py @@ -532,16 +532,17 @@ class RDBMSConnector(BaseConnector): ans = cursor.fetchall() return ans[0][1] - def get_fields(self, table_name) -> List[Tuple]: + def get_fields(self, table_name, db_name=None) -> List[Tuple]: """Get column fields about specified table.""" session = self._db_sessions() - cursor = session.execute( - text( - "SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, " - "COLUMN_COMMENT from information_schema.COLUMNS where " - f"table_name='{table_name}'".format(table_name) - ) + query = ( + "SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, " + "COLUMN_COMMENT from information_schema.COLUMNS where " + f"table_name='{table_name}'" ) + if db_name is not None: + query += f" AND table_schema='{db_name}'" + cursor = session.execute(text(query)) fields = cursor.fetchall() return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] diff --git a/dbgpt/datasource/rdbms/conn_clickhouse.py b/dbgpt/datasource/rdbms/conn_clickhouse.py index 13659af46..e5253d33e 100644 --- a/dbgpt/datasource/rdbms/conn_clickhouse.py +++ b/dbgpt/datasource/rdbms/conn_clickhouse.py @@ -155,16 +155,17 @@ class ClickhouseConnector(RDBMSConnector): """Return string representation of dialect to use.""" return "" - def get_fields(self, table_name) -> List[Tuple]: + def get_fields(self, table_name, db_name=None) -> List[Tuple]: """Get column fields about specified table.""" session = self.client - _query_sql = f""" SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}' """.format( table_name ) + if db_name is not None: + _query_sql += f" AND database='{db_name}'" with session.query_row_block_stream(_query_sql) as stream: fields = [block for block in stream] # noqa return fields diff --git a/dbgpt/datasource/rdbms/conn_doris.py b/dbgpt/datasource/rdbms/conn_doris.py index 9f1aeece2..a068204a3 100644 --- a/dbgpt/datasource/rdbms/conn_doris.py +++ b/dbgpt/datasource/rdbms/conn_doris.py @@ -100,7 +100,7 @@ class DorisConnector(RDBMSConnector): for field in fields ] - def get_fields(self, table_name) -> List[Tuple]: + def get_fields(self, table_name, db_name=None) -> List[Tuple]: """Get column fields about specified table.""" cursor = self.get_session().execute( text( diff --git a/dbgpt/datasource/rdbms/conn_postgresql.py b/dbgpt/datasource/rdbms/conn_postgresql.py index 4829c725c..fb68fe09e 100644 --- a/dbgpt/datasource/rdbms/conn_postgresql.py +++ b/dbgpt/datasource/rdbms/conn_postgresql.py @@ -96,7 +96,7 @@ class PostgreSQLConnector(RDBMSConnector): logger.warning(f"postgresql get users error: {str(e)}") return [] - def get_fields(self, table_name) -> List[Tuple]: + def get_fields(self, table_name, db_name=None) -> List[Tuple]: """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute( diff --git a/dbgpt/datasource/rdbms/conn_sqlite.py b/dbgpt/datasource/rdbms/conn_sqlite.py index 116ceba6e..479aa0d83 100644 --- a/dbgpt/datasource/rdbms/conn_sqlite.py +++ b/dbgpt/datasource/rdbms/conn_sqlite.py @@ -55,7 +55,7 @@ class SQLiteConnector(RDBMSConnector): ans = cursor.fetchall() return ans[0][0] - def get_fields(self, table_name) -> List[Tuple]: + def get_fields(self, table_name, db_name=None) -> List[Tuple]: """Get column fields about specified table.""" cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')")) fields = cursor.fetchall() diff --git a/dbgpt/datasource/rdbms/conn_vertica.py b/dbgpt/datasource/rdbms/conn_vertica.py index 0445d8d2c..984f749e4 100644 --- a/dbgpt/datasource/rdbms/conn_vertica.py +++ b/dbgpt/datasource/rdbms/conn_vertica.py @@ -88,7 +88,7 @@ table name should keep its schema name in " logger.warning(f"vertica get users error: {str(e)}") return [] - def get_fields(self, table_name) -> List[Tuple]: + def get_fields(self, table_name, db_name=None) -> List[Tuple]: """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute(