feat(datasource):add oracle datasource (#2629)

Co-authored-by: luobin <luobin@wondersgroup.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
luowenwu
2025-04-27 16:21:11 +08:00
committed by GitHub
parent 430235bd1b
commit 1b77ed6319
4 changed files with 326 additions and 0 deletions

View File

@@ -47,6 +47,10 @@ datasource_mysql = [
# libpq-dev and libmysqlclient-dev first.
"mysqlclient==2.1.0",
]
datasource_oracle = [
"oracledb==3.1.0", # use python-oracledbnew driver for Oracle
]
datasource_postgres = [
# "psycopg2", # In production, you can install psycopg2 instead of psycopg2-binary
"psycopg2-binary",

View File

@@ -0,0 +1,204 @@
"""Oracle connector using python-oracledb."""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type
from urllib.parse import quote_plus
from sqlalchemy import text
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
ResourceCategory,
auto_register_resource,
)
from dbgpt.datasource.rdbms.base import RDBMSConnector, RDBMSDatasourceParameters
from dbgpt.util.i18n_utils import _
@auto_register_resource(
label=_("Oracle datasource"),
category=ResourceCategory.DATABASE,
tags={"order": TAGS_ORDER_HIGH},
description=_(
"Enterprise-grade relational database with oracledb driver (python-oracledb)."
),
)
@dataclass
class OracleParameters(RDBMSDatasourceParameters):
"""Oracle connection parameters."""
__type__ = "oracle"
driver: str = field(
default="oracle+oracledb", # ✅ 使用 python-oracledb 驱动
metadata={
"help": _("Driver name for Oracle, default is oracle+oracledb."),
},
)
service_name: Optional[str] = field(
default=None,
metadata={
"help": _("Oracle service name (alternative to SID)."),
},
)
sid: Optional[str] = field(
default=None,
metadata={
"help": _("Oracle SID (System ID, alternative to service name)."),
},
)
def db_url(self, ssl: bool = False, charset: Optional[str] = None) -> str:
if self.service_name:
dsn = (
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})"
f"(PORT={self.port}))(CONNECT_DATA=(SERVICE_NAME={self.service_name})))"
)
elif self.sid:
dsn = (
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})"
f"(PORT={self.port}))(CONNECT_DATA=(SID={self.sid})))"
)
else:
raise ValueError("Either service_name or sid must be provided for Oracle.")
return f"{self.driver}://{self.user}:{self.password}@{dsn}"
def create_connector(self) -> "OracleConnector":
return OracleConnector.from_parameters(self)
class OracleConnector(RDBMSConnector):
db_type: str = "oracle"
db_dialect: str = "oracle"
driver: str = "oracle+oracledb"
@classmethod
def param_class(cls) -> Type[RDBMSDatasourceParameters]:
return OracleParameters
@classmethod
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
sid: Optional[str] = None,
service_name: Optional[str] = None,
engine_args: Optional[dict] = None,
**kwargs,
) -> "OracleConnector":
if not sid and not service_name:
raise ValueError("Must provide either sid or service_name")
if service_name:
dsn = (
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)"
f"(HOST={host})(PORT={port}))(CONNECT_DATA=(SERVICE_NAME={service_name})))"
)
else:
dsn = (
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={host})"
f"(PORT={port}))(CONNECT_DATA=(SID={sid})))"
)
bm_pwd = quote_plus(pwd)
db_url = f"{cls.driver}://{user}:{bm_pwd}@{dsn}"
return cls.from_uri(db_url, engine_args=engine_args, **kwargs)
def get_fields(self, table_name: str, db_name=None) -> List[Tuple]:
with self.session_scope() as session:
query = f"""
SELECT col.column_name,
col.data_type,
col.data_default,
col.nullable,
comm.comments
FROM user_tab_columns col
LEFT JOIN user_col_comments comm
ON col.table_name = comm.table_name
AND col.column_name = comm.column_name
WHERE col.table_name = '{table_name.upper()}'
"""
result = session.execute(text(query))
return result.fetchall()
def get_charset(self) -> str:
with self.session_scope() as session:
cursor = session.execute(
text(
"SELECT VALUE FROM NLS_DATABASE_PARAMETERS "
"WHERE PARAMETER = 'NLS_CHARACTERSET'"
)
)
return cursor.fetchone()[0]
def get_grants(self):
with self.session_scope() as session:
cursor = session.execute(text("SELECT privilege FROM user_sys_privs"))
return cursor.fetchall()
def get_users(self) -> List[Tuple[str, None]]:
with self.session_scope() as session:
cursor = session.execute(text("SELECT username FROM all_users"))
return [(row[0], None) for row in cursor.fetchall()]
def get_database_names(self) -> List[str]:
with self.session_scope() as session:
is_cdb = session.execute(text("SELECT CDB FROM V$DATABASE")).fetchone()[0]
if is_cdb == "YES":
pdbs = session.execute(
text("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'")
).fetchall()
return [name[0] for name in pdbs]
else:
return [
session.execute(
text("SELECT sys_context('USERENV', 'CON_NAME') FROM dual")
).fetchone()[0]
]
def get_table_comments(self, db_name: str) -> List[Tuple[str, str]]:
with self.session_scope() as session:
result = session.execute(
text("SELECT table_name, comments FROM user_tab_comments")
)
return [(row[0], row[1]) for row in result.fetchall()]
def get_table_comment(self, table_name: str) -> Dict:
with self.session_scope() as session:
cursor = session.execute(
text(
f"SELECT comments FROM user_tab_comments "
f"WHERE table_name = '{table_name.upper()}'"
)
)
row = cursor.fetchone()
return {"text": row[0] if row else ""}
def get_column_comments(
self, db_name: str, table_name: str
) -> List[Tuple[str, str]]:
with self.session_scope() as session:
cursor = session.execute(
text(f"""
SELECT column_name, comments
FROM user_col_comments
WHERE table_name = '{table_name.upper()}'
""")
)
return [(row[0], row[1]) for row in cursor.fetchall()]
def get_collation(self) -> str:
with self.session_scope() as session:
cursor = session.execute(
text(
"SELECT value FROM NLS_DATABASE_PARAMETERS "
"WHERE parameter = 'NLS_SORT'"
)
)
return cursor.fetchone()[0]

View File

@@ -1,5 +1,6 @@
"""Connection manager."""
import json
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Type
@@ -58,6 +59,9 @@ class ConnectorManager(BaseComponent):
from dbgpt_ext.datasource.rdbms.conn_oceanbase import ( # noqa: F401
OceanBaseConnector,
)
# 添加OracleConnector导入
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector # noqa: F401
from dbgpt_ext.datasource.rdbms.conn_postgresql import ( # noqa: F401
PostgreSQLConnector,
)
@@ -175,6 +179,19 @@ class ConnectorManager(BaseComponent):
if db_type.is_file_db():
db_path = db_config.get("db_path")
return connect_instance.from_file_path(db_path) # type: ignore
elif db_type.value() == "oracle":
logger.info("-------------Oracle Datasource------------")
host = db_config.get("db_host")
port = db_config.get("db_port")
user = db_config.get("db_user")
pwd = db_config.get("db_pwd")
extConfig = db_config.get("ext_config")
dbJson = json.loads(extConfig)
service_name = dbJson.get("service_name", None)
sid = (dbJson.get("sid", None),)
return connect_instance.from_uri_db( # type: ignore
host=host, port=port, user=user, pwd=pwd, service_name=service_name
)
else:
db_host = db_config.get("db_host")
db_port = db_config.get("db_port")

View File

@@ -0,0 +1,101 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
mysql -h 127.0.0.1 -uroot -p -P3307
Enter password:
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 2
Server version: 5.7.41 MySQL Community Server (GPL)
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
> create database test;
"""
import pytest
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
_create_table_sql = """
CREATE TABLE test (
id NUMBER(11) NULL
)
"""
@pytest.fixture
def db():
# 注意Oracle 默认端口是 1521连接方式建议用 service_name
conn = OracleConnector.from_uri_db(
host="localhost",
port=1521,
user="oracle_user",
pwd="********",
service_name="ORCL", # 替换为你的 service_name 或 SID
)
try:
yield conn
finally:
try:
conn.run("DROP TABLE test PURGE")
except Exception:
pass # 如果表不存在也忽略错误
def test_get_usable_table_names(db):
db.run(_create_table_sql)
db.run("COMMIT")
table_names = db.get_usable_table_names()
assert "TEST" in map(str.upper, table_names)
def test_get_table_info(db):
db.run(_create_table_sql)
db.run("COMMIT")
table_info = db.get_table_info()
assert "CREATE TABLE TEST" in table_info.upper()
def test_run_no_throw(db):
result = db.run_no_throw("this is a error sql")
# run_no_throw 返回的是 list错误时为空
assert result == [] or isinstance(result, list)
def test_get_index_empty(db):
db.run(_create_table_sql)
db.run("COMMIT")
indexes = db.get_indexes("TEST")
assert indexes == []
def test_get_fields(db):
#db.run(_create_table_sql)
#db.run("COMMIT")
print("进入方法...")
fields = db.get_fields("PY_TEST")
print("正在打印字段信息...")
for field in fields:
print(f"Column Name: {field[0]}")
print(f"Data Type: {field[1]}")
print(f"Default Value: {field[2]}")
print(f"Is Nullable: {field[3]}")
print(f"Column Comment: {field[4]}")
print("-" * 30) # 可选的分隔符
#assert fields[0][0].upper() == "ID"
def test_get_charset(db):
result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'")
assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组
def test_get_users(db):
users = db.get_users()
assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)
def test_get_database_lists(db):
cdb_result = db.run("SELECT CDB FROM V$DATABASE")
if cdb_result[1][0] == "YES":
databases = db.run("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'")
pdb_names = [name[0] for name in databases[1:]]
else:
pdb_names = ["ORCL"]
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)