mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-18 08:17:38 +00:00
fix(datasource): fix doris DB connection use mysql protocol (#2875)
This commit is contained in:
parent
4e7070b6ee
commit
21896becd9
@ -1,11 +1,13 @@
|
||||
"""Doris connector."""
|
||||
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import MetaData, inspect, text
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
@ -32,9 +34,11 @@ class DorisParameters(RDBMSDatasourceParameters):
|
||||
|
||||
__type__ = "doris"
|
||||
driver: str = field(
|
||||
default="doris",
|
||||
default="mysql+pymysql",
|
||||
metadata={
|
||||
"help": _("Driver name for Doris, default is doris."),
|
||||
"help": _(
|
||||
"Driver name for Doris, default is mysql+pymysql (MySQL compatible)."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@ -46,9 +50,56 @@ class DorisParameters(RDBMSDatasourceParameters):
|
||||
class DorisConnector(RDBMSConnector):
|
||||
"""Doris connector."""
|
||||
|
||||
driver = "doris"
|
||||
driver = "mysql+pymysql"
|
||||
db_type = "doris"
|
||||
db_dialect = "doris"
|
||||
db_dialect = "mysql"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[Dict[str, str]] = None,
|
||||
view_support: bool = False,
|
||||
):
|
||||
"""Initialize Doris connector without triggering reflection.
|
||||
|
||||
Override parent __init__ to avoid automatic metadata.reflect() call
|
||||
which causes issues with Doris data type parsing.
|
||||
"""
|
||||
# Initialize basic attributes (copied from parent but without reflect)
|
||||
self._is_closed = False
|
||||
self._engine = engine
|
||||
self._schema = schema
|
||||
if include_tables and ignore_tables:
|
||||
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||
|
||||
if not custom_table_info:
|
||||
custom_table_info = {}
|
||||
|
||||
self._inspector = inspect(engine)
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
Session_Manages = scoped_session(session_factory)
|
||||
self._db_sessions = Session_Manages
|
||||
self._sessions = weakref.WeakSet()
|
||||
|
||||
self.view_support = view_support
|
||||
self._usable_tables = set()
|
||||
self._include_tables = set()
|
||||
self._ignore_tables = set()
|
||||
self._custom_table_info = custom_table_info
|
||||
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||
self._indexes_in_table_info = indexes_in_table_info
|
||||
|
||||
# NOT call reflect() to avoid Doris type parsing issues
|
||||
# self._metadata = metadata or MetaData()
|
||||
# self._metadata.reflect(bind=self._engine)
|
||||
|
||||
self._all_tables = set(self._sync_tables_from_db())
|
||||
|
||||
@classmethod
|
||||
def param_class(cls) -> Type[DorisParameters]:
|
||||
@ -83,7 +134,6 @@ class DorisConnector(RDBMSConnector):
|
||||
)
|
||||
table_results = set(row[0] for row in table_results) # noqa: C401
|
||||
self._all_tables = table_results
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
@ -158,9 +208,22 @@ class DorisConnector(RDBMSConnector):
|
||||
(field[0], field[1], field[2], field[3], field[4]) for field in fields
|
||||
]
|
||||
|
||||
def get_charset(self):
|
||||
def get_charset(self) -> str:
|
||||
"""Get character_set."""
|
||||
return "utf-8"
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DEFAULT_CHARACTER_SET_NAME
|
||||
FROM information_schema.SCHEMATA
|
||||
where SCHEMA_NAME=database()
|
||||
"""
|
||||
)
|
||||
)
|
||||
ans = cursor.fetchall()
|
||||
if ans:
|
||||
return ans[0][0]
|
||||
return ""
|
||||
|
||||
def get_show_create_table(self, table_name) -> str:
|
||||
"""Get show create table."""
|
||||
@ -260,3 +323,156 @@ class DorisConnector(RDBMSConnector):
|
||||
cursor = session.execute(text(f"SHOW INDEX FROM {table_name}"))
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[2], index[4]) for index in indexes]
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables.
|
||||
|
||||
Override parent method to avoid dependency on metadata.reflect()
|
||||
which causes issues with Doris data type parsing.
|
||||
Uses direct SQL queries to get table information.
|
||||
"""
|
||||
all_table_names = list(self.get_usable_table_names())
|
||||
if table_names is not None:
|
||||
missing_tables = set(table_names).difference(all_table_names)
|
||||
if missing_tables:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
|
||||
if not all_table_names:
|
||||
return ""
|
||||
|
||||
tables = []
|
||||
for table_name in all_table_names:
|
||||
if self._custom_table_info and table_name in self._custom_table_info:
|
||||
tables.append(self._custom_table_info[table_name])
|
||||
continue
|
||||
|
||||
# Build table info using direct SQL queries
|
||||
table_info = self._build_table_info_for_doris(table_name)
|
||||
tables.append(table_info)
|
||||
|
||||
return "\n\n".join(tables)
|
||||
|
||||
def _build_table_info_for_doris(self, table_name: str) -> str:
|
||||
"""Build table information for Doris using direct SQL queries."""
|
||||
try:
|
||||
with self.session_scope() as session:
|
||||
# Get table structure information
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, "
|
||||
"COLUMN_DEFAULT, COLUMN_COMMENT "
|
||||
"FROM information_schema.columns "
|
||||
f'WHERE TABLE_NAME="{table_name}" AND TABLE_SCHEMA=database() '
|
||||
"ORDER BY ORDINAL_POSITION"
|
||||
)
|
||||
)
|
||||
columns = cursor.fetchall()
|
||||
|
||||
if not columns:
|
||||
return f"-- Table {table_name} not found"
|
||||
|
||||
# Build CREATE TABLE statement
|
||||
table_info = f"CREATE TABLE {table_name} (\n"
|
||||
column_definitions = []
|
||||
|
||||
for col in columns:
|
||||
col_name, col_type, is_nullable, col_default, col_comment = col
|
||||
col_def = f" `{col_name}` {col_type}"
|
||||
|
||||
if is_nullable == "NO":
|
||||
col_def += " NOT NULL"
|
||||
|
||||
if col_default is not None:
|
||||
col_def += f" DEFAULT {col_default}"
|
||||
|
||||
if col_comment:
|
||||
col_def += f" COMMENT '{col_comment}'"
|
||||
|
||||
column_definitions.append(col_def)
|
||||
|
||||
table_info += ",\n".join(column_definitions)
|
||||
table_info += "\n)"
|
||||
|
||||
# Get table comment if available
|
||||
try:
|
||||
comment_cursor = session.execute(
|
||||
text(
|
||||
"SELECT TABLE_COMMENT FROM information_schema.tables "
|
||||
f'WHERE TABLE_NAME="{table_name}"'
|
||||
f" AND TABLE_SCHEMA=database()"
|
||||
)
|
||||
)
|
||||
table_comment = comment_cursor.fetchone()
|
||||
if table_comment and table_comment[0]:
|
||||
table_info += f" COMMENT='{table_comment[0]}'"
|
||||
except Exception:
|
||||
pass # Ignore comment retrieval errors
|
||||
|
||||
# Add sample rows if configured
|
||||
if self._sample_rows_in_table_info > 0:
|
||||
table_info += self._get_sample_rows_for_doris(table_name)
|
||||
|
||||
# Add index information if configured
|
||||
if self._indexes_in_table_info:
|
||||
table_info += self._get_indexes_info_for_doris(table_name)
|
||||
|
||||
return table_info
|
||||
|
||||
except Exception as e:
|
||||
return f"-- Error getting info for table {table_name}: {str(e)}"
|
||||
|
||||
def _get_sample_rows_for_doris(self, table_name: str) -> str:
|
||||
"""Get sample rows for Doris table."""
|
||||
try:
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT * FROM {table_name} LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
return ""
|
||||
|
||||
# Get column names
|
||||
column_names = list(cursor.keys())
|
||||
columns_str = "\t".join(column_names)
|
||||
|
||||
# Format sample rows
|
||||
sample_rows_str = "\n".join(
|
||||
[
|
||||
"\t".join(
|
||||
[
|
||||
str(val)[:100] if val is not None else "NULL"
|
||||
for val in row
|
||||
]
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
)
|
||||
|
||||
return (
|
||||
f"\n\n/*\n{self._sample_rows_in_table_info} rows from "
|
||||
f"{table_name} table:\n{columns_str}\n{sample_rows_str}\n*/"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return f"\n\n/*\nError getting sample rows for table {table_name}\n*/"
|
||||
|
||||
def _get_indexes_info_for_doris(self, table_name: str) -> str:
|
||||
"""Get index information for Doris table."""
|
||||
try:
|
||||
indexes = self.get_indexes(table_name)
|
||||
if not indexes:
|
||||
return f"\n\n/*\nTable Indexes for {table_name}:\nNo indexes found\n*/"
|
||||
|
||||
indexes_str = "\n".join(
|
||||
[f"Index: {idx[0]}, Column: {idx[1]}" for idx in indexes]
|
||||
)
|
||||
return f"\n\n/*\nTable Indexes for {table_name}:\n{indexes_str}\n*/"
|
||||
|
||||
except Exception:
|
||||
return f"\n\n/*\nError getting indexes for table {table_name}\n*/"
|
||||
|
@ -1,13 +1,76 @@
|
||||
"""
|
||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_doris.py
|
||||
|
||||
docker run -it -d --name doris -p 8030:8030 -p 8040:8040 -p 9030:9030 -p 8048:8048 apache/doris:doris-all-in-one-2.1.0
|
||||
|
||||
9030: The MySQL protocol port of FE.
|
||||
|
||||
Connection: mysql -uadmin -P9030 -h127.0.0.1
|
||||
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_doris import DorisConnector
|
||||
from dbgpt_ext.datasource.rdbms.conn_doris import DorisConnector
|
||||
|
||||
_create_table_sql = """
|
||||
CREATE TABLE IF NOT EXISTS `test` (
|
||||
`id` int(11) DEFAULT NULL,
|
||||
`name` VARCHAR(200) DEFAULT NULL,
|
||||
`sex` VARCHAR(200) DEFAULT NULL,
|
||||
INDEX idx_name (`name`) USING INVERTED
|
||||
) UNIQUE KEY(`id`)
|
||||
DISTRIBUTED BY HASH(`id`) BUCKETS 10
|
||||
PROPERTIES (
|
||||
"replication_allocation" = "tag.location.default: 1"
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
conn = DorisConnector.from_uri_db("localhost", 9030, "root", "", "test")
|
||||
conn = DorisConnector.from_uri_db("localhost", 9030, "admin", "", "test")
|
||||
yield conn
|
||||
|
||||
|
||||
def test_get_usable_table_names(db):
|
||||
db.run(_create_table_sql)
|
||||
print(db._sync_tables_from_db())
|
||||
assert list(db.get_usable_table_names()) == ['test']
|
||||
|
||||
|
||||
def test_get_table_info(db):
|
||||
db.run(_create_table_sql)
|
||||
print(db._sync_tables_from_db())
|
||||
assert "CREATE TABLE test" in db.get_table_info()
|
||||
|
||||
|
||||
def test_run_no_throw(db):
|
||||
assert db.run_no_throw("this is a error sql") == []
|
||||
|
||||
|
||||
def test_get_index(db):
|
||||
db.run(_create_table_sql)
|
||||
assert db.get_indexes("test") == [('idx_name', 'name')]
|
||||
|
||||
|
||||
def test_get_fields(db):
|
||||
db.run(_create_table_sql)
|
||||
assert list(db.get_fields("test")[0])[0] == "id"
|
||||
|
||||
|
||||
def test_get_charset(db):
|
||||
assert db.get_charset() == "utf8mb4"
|
||||
|
||||
|
||||
def test_get_collation(db):
|
||||
assert (
|
||||
db.get_collation() == "utf8mb4_0900_bin"
|
||||
or db.get_collation() == "utf8mb4_general_ci"
|
||||
)
|
||||
|
||||
def test_get_users(db):
|
||||
assert db.get_users() == []
|
||||
|
||||
def test_get_database_lists(db):
|
||||
assert "test" in db.get_database_names()
|
||||
|
Loading…
Reference in New Issue
Block a user