mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
feat(datasource):add oracle datasource (#2629)
Co-authored-by: luobin <luobin@wondersgroup.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -47,6 +47,10 @@ datasource_mysql = [
|
|||||||
# libpq-dev and libmysqlclient-dev first.
|
# libpq-dev and libmysqlclient-dev first.
|
||||||
"mysqlclient==2.1.0",
|
"mysqlclient==2.1.0",
|
||||||
]
|
]
|
||||||
|
datasource_oracle = [
|
||||||
|
"oracledb==3.1.0", # use python-oracledb,new driver for Oracle
|
||||||
|
]
|
||||||
|
|
||||||
datasource_postgres = [
|
datasource_postgres = [
|
||||||
# "psycopg2", # In production, you can install psycopg2 instead of psycopg2-binary
|
# "psycopg2", # In production, you can install psycopg2 instead of psycopg2-binary
|
||||||
"psycopg2-binary",
|
"psycopg2-binary",
|
||||||
|
204
packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py
Normal file
204
packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""Oracle connector using python-oracledb."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from dbgpt.core.awel.flow import (
|
||||||
|
TAGS_ORDER_HIGH,
|
||||||
|
ResourceCategory,
|
||||||
|
auto_register_resource,
|
||||||
|
)
|
||||||
|
from dbgpt.datasource.rdbms.base import RDBMSConnector, RDBMSDatasourceParameters
|
||||||
|
from dbgpt.util.i18n_utils import _
|
||||||
|
|
||||||
|
|
||||||
|
@auto_register_resource(
|
||||||
|
label=_("Oracle datasource"),
|
||||||
|
category=ResourceCategory.DATABASE,
|
||||||
|
tags={"order": TAGS_ORDER_HIGH},
|
||||||
|
description=_(
|
||||||
|
"Enterprise-grade relational database with oracledb driver (python-oracledb)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@dataclass
|
||||||
|
class OracleParameters(RDBMSDatasourceParameters):
|
||||||
|
"""Oracle connection parameters."""
|
||||||
|
|
||||||
|
__type__ = "oracle"
|
||||||
|
|
||||||
|
driver: str = field(
|
||||||
|
default="oracle+oracledb", # ✅ 使用 python-oracledb 驱动
|
||||||
|
metadata={
|
||||||
|
"help": _("Driver name for Oracle, default is oracle+oracledb."),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
service_name: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": _("Oracle service name (alternative to SID)."),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
sid: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": _("Oracle SID (System ID, alternative to service name)."),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def db_url(self, ssl: bool = False, charset: Optional[str] = None) -> str:
|
||||||
|
if self.service_name:
|
||||||
|
dsn = (
|
||||||
|
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})"
|
||||||
|
f"(PORT={self.port}))(CONNECT_DATA=(SERVICE_NAME={self.service_name})))"
|
||||||
|
)
|
||||||
|
elif self.sid:
|
||||||
|
dsn = (
|
||||||
|
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})"
|
||||||
|
f"(PORT={self.port}))(CONNECT_DATA=(SID={self.sid})))"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Either service_name or sid must be provided for Oracle.")
|
||||||
|
|
||||||
|
return f"{self.driver}://{self.user}:{self.password}@{dsn}"
|
||||||
|
|
||||||
|
def create_connector(self) -> "OracleConnector":
|
||||||
|
return OracleConnector.from_parameters(self)
|
||||||
|
|
||||||
|
|
||||||
|
class OracleConnector(RDBMSConnector):
|
||||||
|
db_type: str = "oracle"
|
||||||
|
db_dialect: str = "oracle"
|
||||||
|
driver: str = "oracle+oracledb"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def param_class(cls) -> Type[RDBMSDatasourceParameters]:
|
||||||
|
return OracleParameters
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_uri_db(
|
||||||
|
cls,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
user: str,
|
||||||
|
pwd: str,
|
||||||
|
sid: Optional[str] = None,
|
||||||
|
service_name: Optional[str] = None,
|
||||||
|
engine_args: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "OracleConnector":
|
||||||
|
if not sid and not service_name:
|
||||||
|
raise ValueError("Must provide either sid or service_name")
|
||||||
|
|
||||||
|
if service_name:
|
||||||
|
dsn = (
|
||||||
|
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)"
|
||||||
|
f"(HOST={host})(PORT={port}))(CONNECT_DATA=(SERVICE_NAME={service_name})))"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dsn = (
|
||||||
|
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={host})"
|
||||||
|
f"(PORT={port}))(CONNECT_DATA=(SID={sid})))"
|
||||||
|
)
|
||||||
|
|
||||||
|
bm_pwd = quote_plus(pwd)
|
||||||
|
db_url = f"{cls.driver}://{user}:{bm_pwd}@{dsn}"
|
||||||
|
|
||||||
|
return cls.from_uri(db_url, engine_args=engine_args, **kwargs)
|
||||||
|
|
||||||
|
def get_fields(self, table_name: str, db_name=None) -> List[Tuple]:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
query = f"""
|
||||||
|
SELECT col.column_name,
|
||||||
|
col.data_type,
|
||||||
|
col.data_default,
|
||||||
|
col.nullable,
|
||||||
|
comm.comments
|
||||||
|
FROM user_tab_columns col
|
||||||
|
LEFT JOIN user_col_comments comm
|
||||||
|
ON col.table_name = comm.table_name
|
||||||
|
AND col.column_name = comm.column_name
|
||||||
|
WHERE col.table_name = '{table_name.upper()}'
|
||||||
|
"""
|
||||||
|
result = session.execute(text(query))
|
||||||
|
return result.fetchall()
|
||||||
|
|
||||||
|
def get_charset(self) -> str:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
"SELECT VALUE FROM NLS_DATABASE_PARAMETERS "
|
||||||
|
"WHERE PARAMETER = 'NLS_CHARACTERSET'"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return cursor.fetchone()[0]
|
||||||
|
|
||||||
|
def get_grants(self):
|
||||||
|
with self.session_scope() as session:
|
||||||
|
cursor = session.execute(text("SELECT privilege FROM user_sys_privs"))
|
||||||
|
return cursor.fetchall()
|
||||||
|
|
||||||
|
def get_users(self) -> List[Tuple[str, None]]:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
cursor = session.execute(text("SELECT username FROM all_users"))
|
||||||
|
return [(row[0], None) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def get_database_names(self) -> List[str]:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
is_cdb = session.execute(text("SELECT CDB FROM V$DATABASE")).fetchone()[0]
|
||||||
|
if is_cdb == "YES":
|
||||||
|
pdbs = session.execute(
|
||||||
|
text("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'")
|
||||||
|
).fetchall()
|
||||||
|
return [name[0] for name in pdbs]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
session.execute(
|
||||||
|
text("SELECT sys_context('USERENV', 'CON_NAME') FROM dual")
|
||||||
|
).fetchone()[0]
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_table_comments(self, db_name: str) -> List[Tuple[str, str]]:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
result = session.execute(
|
||||||
|
text("SELECT table_name, comments FROM user_tab_comments")
|
||||||
|
)
|
||||||
|
return [(row[0], row[1]) for row in result.fetchall()]
|
||||||
|
|
||||||
|
def get_table_comment(self, table_name: str) -> Dict:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
f"SELECT comments FROM user_tab_comments "
|
||||||
|
f"WHERE table_name = '{table_name.upper()}'"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
return {"text": row[0] if row else ""}
|
||||||
|
|
||||||
|
def get_column_comments(
|
||||||
|
self, db_name: str, table_name: str
|
||||||
|
) -> List[Tuple[str, str]]:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
cursor = session.execute(
|
||||||
|
text(f"""
|
||||||
|
SELECT column_name, comments
|
||||||
|
FROM user_col_comments
|
||||||
|
WHERE table_name = '{table_name.upper()}'
|
||||||
|
""")
|
||||||
|
)
|
||||||
|
return [(row[0], row[1]) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def get_collation(self) -> str:
|
||||||
|
with self.session_scope() as session:
|
||||||
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
"SELECT value FROM NLS_DATABASE_PARAMETERS "
|
||||||
|
"WHERE parameter = 'NLS_SORT'"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return cursor.fetchone()[0]
|
@@ -1,5 +1,6 @@
|
|||||||
"""Connection manager."""
|
"""Connection manager."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type
|
from typing import TYPE_CHECKING, Dict, List, Optional, Type
|
||||||
|
|
||||||
@@ -58,6 +59,9 @@ class ConnectorManager(BaseComponent):
|
|||||||
from dbgpt_ext.datasource.rdbms.conn_oceanbase import ( # noqa: F401
|
from dbgpt_ext.datasource.rdbms.conn_oceanbase import ( # noqa: F401
|
||||||
OceanBaseConnector,
|
OceanBaseConnector,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加OracleConnector导入
|
||||||
|
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector # noqa: F401
|
||||||
from dbgpt_ext.datasource.rdbms.conn_postgresql import ( # noqa: F401
|
from dbgpt_ext.datasource.rdbms.conn_postgresql import ( # noqa: F401
|
||||||
PostgreSQLConnector,
|
PostgreSQLConnector,
|
||||||
)
|
)
|
||||||
@@ -175,6 +179,19 @@ class ConnectorManager(BaseComponent):
|
|||||||
if db_type.is_file_db():
|
if db_type.is_file_db():
|
||||||
db_path = db_config.get("db_path")
|
db_path = db_config.get("db_path")
|
||||||
return connect_instance.from_file_path(db_path) # type: ignore
|
return connect_instance.from_file_path(db_path) # type: ignore
|
||||||
|
elif db_type.value() == "oracle":
|
||||||
|
logger.info("-------------Oracle Datasource------------")
|
||||||
|
host = db_config.get("db_host")
|
||||||
|
port = db_config.get("db_port")
|
||||||
|
user = db_config.get("db_user")
|
||||||
|
pwd = db_config.get("db_pwd")
|
||||||
|
extConfig = db_config.get("ext_config")
|
||||||
|
dbJson = json.loads(extConfig)
|
||||||
|
service_name = dbJson.get("service_name", None)
|
||||||
|
sid = (dbJson.get("sid", None),)
|
||||||
|
return connect_instance.from_uri_db( # type: ignore
|
||||||
|
host=host, port=port, user=user, pwd=pwd, service_name=service_name
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
db_host = db_config.get("db_host")
|
db_host = db_config.get("db_host")
|
||||||
db_port = db_config.get("db_port")
|
db_port = db_config.get("db_port")
|
||||||
|
101
tests/intetration_tests/datasource/test_conn_oracle.py
Normal file
101
tests/intetration_tests/datasource/test_conn_oracle.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
|
||||||
|
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
|
||||||
|
mysql -h 127.0.0.1 -uroot -p -P3307
|
||||||
|
Enter password:
|
||||||
|
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||||
|
Your MySQL connection id is 2
|
||||||
|
Server version: 5.7.41 MySQL Community Server (GPL)
|
||||||
|
|
||||||
|
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||||
|
|
||||||
|
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||||
|
affiliates. Other names may be trademarks of their respective
|
||||||
|
owners.
|
||||||
|
|
||||||
|
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||||
|
|
||||||
|
> create database test;
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
|
||||||
|
|
||||||
|
|
||||||
|
_create_table_sql = """
|
||||||
|
CREATE TABLE test (
|
||||||
|
id NUMBER(11) NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db():
|
||||||
|
# 注意:Oracle 默认端口是 1521,连接方式建议用 service_name
|
||||||
|
conn = OracleConnector.from_uri_db(
|
||||||
|
host="localhost",
|
||||||
|
port=1521,
|
||||||
|
user="oracle_user",
|
||||||
|
pwd="********",
|
||||||
|
service_name="ORCL", # 替换为你的 service_name 或 SID
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.run("DROP TABLE test PURGE")
|
||||||
|
except Exception:
|
||||||
|
pass # 如果表不存在也忽略错误
|
||||||
|
|
||||||
|
def test_get_usable_table_names(db):
|
||||||
|
db.run(_create_table_sql)
|
||||||
|
db.run("COMMIT")
|
||||||
|
table_names = db.get_usable_table_names()
|
||||||
|
assert "TEST" in map(str.upper, table_names)
|
||||||
|
|
||||||
|
def test_get_table_info(db):
|
||||||
|
db.run(_create_table_sql)
|
||||||
|
db.run("COMMIT")
|
||||||
|
table_info = db.get_table_info()
|
||||||
|
assert "CREATE TABLE TEST" in table_info.upper()
|
||||||
|
|
||||||
|
def test_run_no_throw(db):
|
||||||
|
result = db.run_no_throw("this is a error sql")
|
||||||
|
# run_no_throw 返回的是 list,错误时为空
|
||||||
|
assert result == [] or isinstance(result, list)
|
||||||
|
|
||||||
|
def test_get_index_empty(db):
|
||||||
|
db.run(_create_table_sql)
|
||||||
|
db.run("COMMIT")
|
||||||
|
indexes = db.get_indexes("TEST")
|
||||||
|
assert indexes == []
|
||||||
|
|
||||||
|
def test_get_fields(db):
|
||||||
|
#db.run(_create_table_sql)
|
||||||
|
#db.run("COMMIT")
|
||||||
|
print("进入方法...")
|
||||||
|
fields = db.get_fields("PY_TEST")
|
||||||
|
print("正在打印字段信息...")
|
||||||
|
for field in fields:
|
||||||
|
print(f"Column Name: {field[0]}")
|
||||||
|
print(f"Data Type: {field[1]}")
|
||||||
|
print(f"Default Value: {field[2]}")
|
||||||
|
print(f"Is Nullable: {field[3]}")
|
||||||
|
print(f"Column Comment: {field[4]}")
|
||||||
|
print("-" * 30) # 可选的分隔符
|
||||||
|
#assert fields[0][0].upper() == "ID"
|
||||||
|
|
||||||
|
def test_get_charset(db):
|
||||||
|
result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'")
|
||||||
|
assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组
|
||||||
|
|
||||||
|
def test_get_users(db):
|
||||||
|
users = db.get_users()
|
||||||
|
assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)
|
||||||
|
|
||||||
|
def test_get_database_lists(db):
|
||||||
|
cdb_result = db.run("SELECT CDB FROM V$DATABASE")
|
||||||
|
if cdb_result[1][0] == "YES":
|
||||||
|
databases = db.run("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'")
|
||||||
|
pdb_names = [name[0] for name in databases[1:]]
|
||||||
|
else:
|
||||||
|
pdb_names = ["ORCL"]
|
||||||
|
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)
|
Reference in New Issue
Block a user