fix(datasource): fix doris DB connection use mysql protocol (#2875)

This commit is contained in:
alanchen 2025-08-07 10:34:14 +08:00 committed by GitHub
parent 4e7070b6ee
commit 21896becd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 289 additions and 10 deletions

View File

@ -1,11 +1,13 @@
"""Doris connector."""
import weakref
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from sqlalchemy import MetaData, inspect, text
from sqlalchemy.orm import scoped_session, sessionmaker
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
@ -32,9 +34,11 @@ class DorisParameters(RDBMSDatasourceParameters):
__type__ = "doris"
driver: str = field(
default="doris",
default="mysql+pymysql",
metadata={
"help": _("Driver name for Doris, default is doris."),
"help": _(
"Driver name for Doris, default is mysql+pymysql (MySQL compatible)."
),
},
)
@ -46,9 +50,56 @@ class DorisParameters(RDBMSDatasourceParameters):
class DorisConnector(RDBMSConnector):
"""Doris connector."""
driver = "doris"
driver = "mysql+pymysql"
db_type = "doris"
db_dialect = "doris"
db_dialect = "mysql"
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[Dict[str, str]] = None,
view_support: bool = False,
):
"""Initialize Doris connector without triggering reflection.
Override parent __init__ to avoid automatic metadata.reflect() call
which causes issues with Doris data type parsing.
"""
# Initialize basic attributes (copied from parent but without reflect)
self._is_closed = False
self._engine = engine
self._schema = schema
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
if not custom_table_info:
custom_table_info = {}
self._inspector = inspect(engine)
session_factory = sessionmaker(bind=engine)
Session_Manages = scoped_session(session_factory)
self._db_sessions = Session_Manages
self._sessions = weakref.WeakSet()
self.view_support = view_support
self._usable_tables = set()
self._include_tables = set()
self._ignore_tables = set()
self._custom_table_info = custom_table_info
self._sample_rows_in_table_info = sample_rows_in_table_info
self._indexes_in_table_info = indexes_in_table_info
# NOT call reflect() to avoid Doris type parsing issues
# self._metadata = metadata or MetaData()
# self._metadata.reflect(bind=self._engine)
self._all_tables = set(self._sync_tables_from_db())
@classmethod
def param_class(cls) -> Type[DorisParameters]:
@ -83,7 +134,6 @@ class DorisConnector(RDBMSConnector):
)
table_results = set(row[0] for row in table_results) # noqa: C401
self._all_tables = table_results
self._metadata.reflect(bind=self._engine)
return self._all_tables
def get_grants(self):
@ -158,9 +208,22 @@ class DorisConnector(RDBMSConnector):
(field[0], field[1], field[2], field[3], field[4]) for field in fields
]
def get_charset(self):
def get_charset(self) -> str:
"""Get character_set."""
return "utf-8"
with self.session_scope() as session:
cursor = session.execute(
text(
"""
SELECT DEFAULT_CHARACTER_SET_NAME
FROM information_schema.SCHEMATA
where SCHEMA_NAME=database()
"""
)
)
ans = cursor.fetchall()
if ans:
return ans[0][0]
return ""
def get_show_create_table(self, table_name) -> str:
"""Get show create table."""
@ -260,3 +323,156 @@ class DorisConnector(RDBMSConnector):
cursor = session.execute(text(f"SHOW INDEX FROM {table_name}"))
indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes]
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Override parent method to avoid dependency on metadata.reflect()
which causes issues with Doris data type parsing.
Uses direct SQL queries to get table information.
"""
all_table_names = list(self.get_usable_table_names())
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
if not all_table_names:
return ""
tables = []
for table_name in all_table_names:
if self._custom_table_info and table_name in self._custom_table_info:
tables.append(self._custom_table_info[table_name])
continue
# Build table info using direct SQL queries
table_info = self._build_table_info_for_doris(table_name)
tables.append(table_info)
return "\n\n".join(tables)
def _build_table_info_for_doris(self, table_name: str) -> str:
"""Build table information for Doris using direct SQL queries."""
try:
with self.session_scope() as session:
# Get table structure information
cursor = session.execute(
text(
"SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, "
"COLUMN_DEFAULT, COLUMN_COMMENT "
"FROM information_schema.columns "
f'WHERE TABLE_NAME="{table_name}" AND TABLE_SCHEMA=database() '
"ORDER BY ORDINAL_POSITION"
)
)
columns = cursor.fetchall()
if not columns:
return f"-- Table {table_name} not found"
# Build CREATE TABLE statement
table_info = f"CREATE TABLE {table_name} (\n"
column_definitions = []
for col in columns:
col_name, col_type, is_nullable, col_default, col_comment = col
col_def = f" `{col_name}` {col_type}"
if is_nullable == "NO":
col_def += " NOT NULL"
if col_default is not None:
col_def += f" DEFAULT {col_default}"
if col_comment:
col_def += f" COMMENT '{col_comment}'"
column_definitions.append(col_def)
table_info += ",\n".join(column_definitions)
table_info += "\n)"
# Get table comment if available
try:
comment_cursor = session.execute(
text(
"SELECT TABLE_COMMENT FROM information_schema.tables "
f'WHERE TABLE_NAME="{table_name}"'
f" AND TABLE_SCHEMA=database()"
)
)
table_comment = comment_cursor.fetchone()
if table_comment and table_comment[0]:
table_info += f" COMMENT='{table_comment[0]}'"
except Exception:
pass # Ignore comment retrieval errors
# Add sample rows if configured
if self._sample_rows_in_table_info > 0:
table_info += self._get_sample_rows_for_doris(table_name)
# Add index information if configured
if self._indexes_in_table_info:
table_info += self._get_indexes_info_for_doris(table_name)
return table_info
except Exception as e:
return f"-- Error getting info for table {table_name}: {str(e)}"
def _get_sample_rows_for_doris(self, table_name: str) -> str:
"""Get sample rows for Doris table."""
try:
with self.session_scope() as session:
cursor = session.execute(
text(
f"SELECT * FROM {table_name} LIMIT "
f"{self._sample_rows_in_table_info}"
)
)
rows = cursor.fetchall()
if not rows:
return ""
# Get column names
column_names = list(cursor.keys())
columns_str = "\t".join(column_names)
# Format sample rows
sample_rows_str = "\n".join(
[
"\t".join(
[
str(val)[:100] if val is not None else "NULL"
for val in row
]
)
for row in rows
]
)
return (
f"\n\n/*\n{self._sample_rows_in_table_info} rows from "
f"{table_name} table:\n{columns_str}\n{sample_rows_str}\n*/"
)
except Exception:
return f"\n\n/*\nError getting sample rows for table {table_name}\n*/"
def _get_indexes_info_for_doris(self, table_name: str) -> str:
"""Get index information for Doris table."""
try:
indexes = self.get_indexes(table_name)
if not indexes:
return f"\n\n/*\nTable Indexes for {table_name}:\nNo indexes found\n*/"
indexes_str = "\n".join(
[f"Index: {idx[0]}, Column: {idx[1]}" for idx in indexes]
)
return f"\n\n/*\nTable Indexes for {table_name}:\n{indexes_str}\n*/"
except Exception:
return f"\n\n/*\nError getting indexes for table {table_name}\n*/"

View File

@ -1,13 +1,76 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_doris.py
docker run -it -d --name doris -p 8030:8030 -p 8040:8040 -p 9030:9030 -p 8048:8048 apache/doris:doris-all-in-one-2.1.0
9030: The MySQL protocol port of FE.
Connection: mysql -uadmin -P9030 -h127.0.0.1
"""
import pytest
from dbgpt.datasource.rdbms.conn_doris import DorisConnector
from dbgpt_ext.datasource.rdbms.conn_doris import DorisConnector
_create_table_sql = """
CREATE TABLE IF NOT EXISTS `test` (
`id` int(11) DEFAULT NULL,
`name` VARCHAR(200) DEFAULT NULL,
`sex` VARCHAR(200) DEFAULT NULL,
INDEX idx_name (`name`) USING INVERTED
) UNIQUE KEY(`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 10
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""
@pytest.fixture
def db():
conn = DorisConnector.from_uri_db("localhost", 9030, "root", "", "test")
conn = DorisConnector.from_uri_db("localhost", 9030, "admin", "", "test")
yield conn
def test_get_usable_table_names(db):
db.run(_create_table_sql)
print(db._sync_tables_from_db())
assert list(db.get_usable_table_names()) == ['test']
def test_get_table_info(db):
db.run(_create_table_sql)
print(db._sync_tables_from_db())
assert "CREATE TABLE test" in db.get_table_info()
def test_run_no_throw(db):
assert db.run_no_throw("this is a error sql") == []
def test_get_index(db):
db.run(_create_table_sql)
assert db.get_indexes("test") == [('idx_name', 'name')]
def test_get_fields(db):
db.run(_create_table_sql)
assert list(db.get_fields("test")[0])[0] == "id"
def test_get_charset(db):
assert db.get_charset() == "utf8mb4"
def test_get_collation(db):
assert (
db.get_collation() == "utf8mb4_0900_bin"
or db.get_collation() == "utf8mb4_general_ci"
)
def test_get_users(db):
assert db.get_users() == []
def test_get_database_lists(db):
assert "test" in db.get_database_names()