diff --git a/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_doris.py b/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_doris.py index e8833a027..deec36c0c 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_doris.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_doris.py @@ -1,11 +1,13 @@ """Doris connector.""" +import weakref from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast from urllib.parse import quote from urllib.parse import quote_plus as urlquote -from sqlalchemy import text +from sqlalchemy import MetaData, inspect, text +from sqlalchemy.orm import scoped_session, sessionmaker from dbgpt.core.awel.flow import ( TAGS_ORDER_HIGH, @@ -32,9 +34,11 @@ class DorisParameters(RDBMSDatasourceParameters): __type__ = "doris" driver: str = field( - default="doris", + default="mysql+pymysql", metadata={ - "help": _("Driver name for Doris, default is doris."), + "help": _( + "Driver name for Doris, default is mysql+pymysql (MySQL compatible)." + ), }, ) @@ -46,9 +50,56 @@ class DorisParameters(RDBMSDatasourceParameters): class DorisConnector(RDBMSConnector): """Doris connector.""" - driver = "doris" + driver = "mysql+pymysql" db_type = "doris" - db_dialect = "doris" + db_dialect = "mysql" + + def __init__( + self, + engine, + schema: Optional[str] = None, + metadata: Optional[MetaData] = None, + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, + sample_rows_in_table_info: int = 3, + indexes_in_table_info: bool = False, + custom_table_info: Optional[Dict[str, str]] = None, + view_support: bool = False, + ): + """Initialize Doris connector without triggering reflection. + + Override parent __init__ to avoid automatic metadata.reflect() call + which causes issues with Doris data type parsing. + """ + # Initialize basic attributes (copied from parent but without reflect) + self._is_closed = False + self._engine = engine + self._schema = schema + if include_tables and ignore_tables: + raise ValueError("Cannot specify both include_tables and ignore_tables") + + if not custom_table_info: + custom_table_info = {} + + self._inspector = inspect(engine) + session_factory = sessionmaker(bind=engine) + Session_Manages = scoped_session(session_factory) + self._db_sessions = Session_Manages + self._sessions = weakref.WeakSet() + + self.view_support = view_support + self._usable_tables = set() + self._include_tables = set() + self._ignore_tables = set() + self._custom_table_info = custom_table_info + self._sample_rows_in_table_info = sample_rows_in_table_info + self._indexes_in_table_info = indexes_in_table_info + + # NOT call reflect() to avoid Doris type parsing issues + # self._metadata = metadata or MetaData() + # self._metadata.reflect(bind=self._engine) + + self._all_tables = set(self._sync_tables_from_db()) @classmethod def param_class(cls) -> Type[DorisParameters]: @@ -83,7 +134,6 @@ class DorisConnector(RDBMSConnector): ) table_results = set(row[0] for row in table_results) # noqa: C401 self._all_tables = table_results - self._metadata.reflect(bind=self._engine) return self._all_tables def get_grants(self): @@ -158,9 +208,22 @@ class DorisConnector(RDBMSConnector): (field[0], field[1], field[2], field[3], field[4]) for field in fields ] - def get_charset(self): + def get_charset(self) -> str: """Get character_set.""" - return "utf-8" + with self.session_scope() as session: + cursor = session.execute( + text( + """ + SELECT DEFAULT_CHARACTER_SET_NAME + FROM information_schema.SCHEMATA + where SCHEMA_NAME=database() + """ + ) + ) + ans = cursor.fetchall() + if ans: + return ans[0][0] + return "" def get_show_create_table(self, table_name) -> str: """Get show create table.""" @@ -260,3 +323,156 @@ class DorisConnector(RDBMSConnector): cursor = session.execute(text(f"SHOW INDEX FROM {table_name}")) indexes = cursor.fetchall() return [(index[2], index[4]) for index in indexes] + + def get_table_info(self, table_names: Optional[List[str]] = None) -> str: + """Get information about specified tables. + + Override parent method to avoid dependency on metadata.reflect() + which causes issues with Doris data type parsing. + Uses direct SQL queries to get table information. + """ + all_table_names = list(self.get_usable_table_names()) + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + raise ValueError(f"table_names {missing_tables} not found in database") + all_table_names = table_names + + if not all_table_names: + return "" + + tables = [] + for table_name in all_table_names: + if self._custom_table_info and table_name in self._custom_table_info: + tables.append(self._custom_table_info[table_name]) + continue + + # Build table info using direct SQL queries + table_info = self._build_table_info_for_doris(table_name) + tables.append(table_info) + + return "\n\n".join(tables) + + def _build_table_info_for_doris(self, table_name: str) -> str: + """Build table information for Doris using direct SQL queries.""" + try: + with self.session_scope() as session: + # Get table structure information + cursor = session.execute( + text( + "SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, " + "COLUMN_DEFAULT, COLUMN_COMMENT " + "FROM information_schema.columns " + f'WHERE TABLE_NAME="{table_name}" AND TABLE_SCHEMA=database() ' + "ORDER BY ORDINAL_POSITION" + ) + ) + columns = cursor.fetchall() + + if not columns: + return f"-- Table {table_name} not found" + + # Build CREATE TABLE statement + table_info = f"CREATE TABLE {table_name} (\n" + column_definitions = [] + + for col in columns: + col_name, col_type, is_nullable, col_default, col_comment = col + col_def = f" `{col_name}` {col_type}" + + if is_nullable == "NO": + col_def += " NOT NULL" + + if col_default is not None: + col_def += f" DEFAULT {col_default}" + + if col_comment: + col_def += f" COMMENT '{col_comment}'" + + column_definitions.append(col_def) + + table_info += ",\n".join(column_definitions) + table_info += "\n)" + + # Get table comment if available + try: + comment_cursor = session.execute( + text( + "SELECT TABLE_COMMENT FROM information_schema.tables " + f'WHERE TABLE_NAME="{table_name}"' + f" AND TABLE_SCHEMA=database()" + ) + ) + table_comment = comment_cursor.fetchone() + if table_comment and table_comment[0]: + table_info += f" COMMENT='{table_comment[0]}'" + except Exception: + pass # Ignore comment retrieval errors + + # Add sample rows if configured + if self._sample_rows_in_table_info > 0: + table_info += self._get_sample_rows_for_doris(table_name) + + # Add index information if configured + if self._indexes_in_table_info: + table_info += self._get_indexes_info_for_doris(table_name) + + return table_info + + except Exception as e: + return f"-- Error getting info for table {table_name}: {str(e)}" + + def _get_sample_rows_for_doris(self, table_name: str) -> str: + """Get sample rows for Doris table.""" + try: + with self.session_scope() as session: + cursor = session.execute( + text( + f"SELECT * FROM {table_name} LIMIT " + f"{self._sample_rows_in_table_info}" + ) + ) + rows = cursor.fetchall() + + if not rows: + return "" + + # Get column names + column_names = list(cursor.keys()) + columns_str = "\t".join(column_names) + + # Format sample rows + sample_rows_str = "\n".join( + [ + "\t".join( + [ + str(val)[:100] if val is not None else "NULL" + for val in row + ] + ) + for row in rows + ] + ) + + return ( + f"\n\n/*\n{self._sample_rows_in_table_info} rows from " + f"{table_name} table:\n{columns_str}\n{sample_rows_str}\n*/" + ) + + except Exception: + return f"\n\n/*\nError getting sample rows for table {table_name}\n*/" + + def _get_indexes_info_for_doris(self, table_name: str) -> str: + """Get index information for Doris table.""" + try: + indexes = self.get_indexes(table_name) + if not indexes: + return f"\n\n/*\nTable Indexes for {table_name}:\nNo indexes found\n*/" + + indexes_str = "\n".join( + [f"Index: {idx[0]}, Column: {idx[1]}" for idx in indexes] + ) + return f"\n\n/*\nTable Indexes for {table_name}:\n{indexes_str}\n*/" + + except Exception: + return f"\n\n/*\nError getting indexes for table {table_name}\n*/" diff --git a/tests/intetration_tests/datasource/test_conn_doris.py b/tests/intetration_tests/datasource/test_conn_doris.py index 33364f413..6bebade00 100644 --- a/tests/intetration_tests/datasource/test_conn_doris.py +++ b/tests/intetration_tests/datasource/test_conn_doris.py @@ -1,13 +1,76 @@ """ Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_doris.py + + docker run -it -d --name doris -p 8030:8030 -p 8040:8040 -p 9030:9030 -p 8048:8048 apache/doris:doris-all-in-one-2.1.0 + + 9030: The MySQL protocol port of FE. + + Connection: mysql -uadmin -P9030 -h127.0.0.1 + """ import pytest -from dbgpt.datasource.rdbms.conn_doris import DorisConnector +from dbgpt_ext.datasource.rdbms.conn_doris import DorisConnector + +_create_table_sql = """ + CREATE TABLE IF NOT EXISTS `test` ( + `id` int(11) DEFAULT NULL, + `name` VARCHAR(200) DEFAULT NULL, + `sex` VARCHAR(200) DEFAULT NULL, + INDEX idx_name (`name`) USING INVERTED + ) UNIQUE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 10 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ @pytest.fixture def db(): - conn = DorisConnector.from_uri_db("localhost", 9030, "root", "", "test") + conn = DorisConnector.from_uri_db("localhost", 9030, "admin", "", "test") yield conn + + +def test_get_usable_table_names(db): + db.run(_create_table_sql) + print(db._sync_tables_from_db()) + assert list(db.get_usable_table_names()) == ['test'] + + +def test_get_table_info(db): + db.run(_create_table_sql) + print(db._sync_tables_from_db()) + assert "CREATE TABLE test" in db.get_table_info() + + +def test_run_no_throw(db): + assert db.run_no_throw("this is a error sql") == [] + + +def test_get_index(db): + db.run(_create_table_sql) + assert db.get_indexes("test") == [('idx_name', 'name')] + + +def test_get_fields(db): + db.run(_create_table_sql) + assert list(db.get_fields("test")[0])[0] == "id" + + +def test_get_charset(db): + assert db.get_charset() == "utf8mb4" + + +def test_get_collation(db): + assert ( + db.get_collation() == "utf8mb4_0900_bin" + or db.get_collation() == "utf8mb4_general_ci" + ) + +def test_get_users(db): + assert db.get_users() == [] + +def test_get_database_lists(db): + assert "test" in db.get_database_names()