mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-30 23:56:25 +00:00
feat(datasource):add oracle datasource (#2629)
Co-authored-by: luobin <luobin@wondersgroup.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -47,6 +47,10 @@ datasource_mysql = [
|
||||
# libpq-dev and libmysqlclient-dev first.
|
||||
"mysqlclient==2.1.0",
|
||||
]
|
||||
datasource_oracle = [
|
||||
"oracledb==3.1.0", # use python-oracledb,new driver for Oracle
|
||||
]
|
||||
|
||||
datasource_postgres = [
|
||||
# "psycopg2", # In production, you can install psycopg2 instead of psycopg2-binary
|
||||
"psycopg2-binary",
|
||||
|
204
packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py
Normal file
204
packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Oracle connector using python-oracledb."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
ResourceCategory,
|
||||
auto_register_resource,
|
||||
)
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector, RDBMSDatasourceParameters
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
|
||||
@auto_register_resource(
|
||||
label=_("Oracle datasource"),
|
||||
category=ResourceCategory.DATABASE,
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
description=_(
|
||||
"Enterprise-grade relational database with oracledb driver (python-oracledb)."
|
||||
),
|
||||
)
|
||||
@dataclass
|
||||
class OracleParameters(RDBMSDatasourceParameters):
|
||||
"""Oracle connection parameters."""
|
||||
|
||||
__type__ = "oracle"
|
||||
|
||||
driver: str = field(
|
||||
default="oracle+oracledb", # ✅ 使用 python-oracledb 驱动
|
||||
metadata={
|
||||
"help": _("Driver name for Oracle, default is oracle+oracledb."),
|
||||
},
|
||||
)
|
||||
|
||||
service_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("Oracle service name (alternative to SID)."),
|
||||
},
|
||||
)
|
||||
|
||||
sid: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("Oracle SID (System ID, alternative to service name)."),
|
||||
},
|
||||
)
|
||||
|
||||
def db_url(self, ssl: bool = False, charset: Optional[str] = None) -> str:
|
||||
if self.service_name:
|
||||
dsn = (
|
||||
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})"
|
||||
f"(PORT={self.port}))(CONNECT_DATA=(SERVICE_NAME={self.service_name})))"
|
||||
)
|
||||
elif self.sid:
|
||||
dsn = (
|
||||
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})"
|
||||
f"(PORT={self.port}))(CONNECT_DATA=(SID={self.sid})))"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either service_name or sid must be provided for Oracle.")
|
||||
|
||||
return f"{self.driver}://{self.user}:{self.password}@{dsn}"
|
||||
|
||||
def create_connector(self) -> "OracleConnector":
|
||||
return OracleConnector.from_parameters(self)
|
||||
|
||||
|
||||
class OracleConnector(RDBMSConnector):
|
||||
db_type: str = "oracle"
|
||||
db_dialect: str = "oracle"
|
||||
driver: str = "oracle+oracledb"
|
||||
|
||||
@classmethod
|
||||
def param_class(cls) -> Type[RDBMSDatasourceParameters]:
|
||||
return OracleParameters
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
sid: Optional[str] = None,
|
||||
service_name: Optional[str] = None,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> "OracleConnector":
|
||||
if not sid and not service_name:
|
||||
raise ValueError("Must provide either sid or service_name")
|
||||
|
||||
if service_name:
|
||||
dsn = (
|
||||
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)"
|
||||
f"(HOST={host})(PORT={port}))(CONNECT_DATA=(SERVICE_NAME={service_name})))"
|
||||
)
|
||||
else:
|
||||
dsn = (
|
||||
f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={host})"
|
||||
f"(PORT={port}))(CONNECT_DATA=(SID={sid})))"
|
||||
)
|
||||
|
||||
bm_pwd = quote_plus(pwd)
|
||||
db_url = f"{cls.driver}://{user}:{bm_pwd}@{dsn}"
|
||||
|
||||
return cls.from_uri(db_url, engine_args=engine_args, **kwargs)
|
||||
|
||||
def get_fields(self, table_name: str, db_name=None) -> List[Tuple]:
|
||||
with self.session_scope() as session:
|
||||
query = f"""
|
||||
SELECT col.column_name,
|
||||
col.data_type,
|
||||
col.data_default,
|
||||
col.nullable,
|
||||
comm.comments
|
||||
FROM user_tab_columns col
|
||||
LEFT JOIN user_col_comments comm
|
||||
ON col.table_name = comm.table_name
|
||||
AND col.column_name = comm.column_name
|
||||
WHERE col.table_name = '{table_name.upper()}'
|
||||
"""
|
||||
result = session.execute(text(query))
|
||||
return result.fetchall()
|
||||
|
||||
def get_charset(self) -> str:
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT VALUE FROM NLS_DATABASE_PARAMETERS "
|
||||
"WHERE PARAMETER = 'NLS_CHARACTERSET'"
|
||||
)
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def get_grants(self):
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(text("SELECT privilege FROM user_sys_privs"))
|
||||
return cursor.fetchall()
|
||||
|
||||
def get_users(self) -> List[Tuple[str, None]]:
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(text("SELECT username FROM all_users"))
|
||||
return [(row[0], None) for row in cursor.fetchall()]
|
||||
|
||||
def get_database_names(self) -> List[str]:
|
||||
with self.session_scope() as session:
|
||||
is_cdb = session.execute(text("SELECT CDB FROM V$DATABASE")).fetchone()[0]
|
||||
if is_cdb == "YES":
|
||||
pdbs = session.execute(
|
||||
text("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'")
|
||||
).fetchall()
|
||||
return [name[0] for name in pdbs]
|
||||
else:
|
||||
return [
|
||||
session.execute(
|
||||
text("SELECT sys_context('USERENV', 'CON_NAME') FROM dual")
|
||||
).fetchone()[0]
|
||||
]
|
||||
|
||||
def get_table_comments(self, db_name: str) -> List[Tuple[str, str]]:
|
||||
with self.session_scope() as session:
|
||||
result = session.execute(
|
||||
text("SELECT table_name, comments FROM user_tab_comments")
|
||||
)
|
||||
return [(row[0], row[1]) for row in result.fetchall()]
|
||||
|
||||
def get_table_comment(self, table_name: str) -> Dict:
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT comments FROM user_tab_comments "
|
||||
f"WHERE table_name = '{table_name.upper()}'"
|
||||
)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return {"text": row[0] if row else ""}
|
||||
|
||||
def get_column_comments(
|
||||
self, db_name: str, table_name: str
|
||||
) -> List[Tuple[str, str]]:
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(f"""
|
||||
SELECT column_name, comments
|
||||
FROM user_col_comments
|
||||
WHERE table_name = '{table_name.upper()}'
|
||||
""")
|
||||
)
|
||||
return [(row[0], row[1]) for row in cursor.fetchall()]
|
||||
|
||||
def get_collation(self) -> str:
|
||||
with self.session_scope() as session:
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT value FROM NLS_DATABASE_PARAMETERS "
|
||||
"WHERE parameter = 'NLS_SORT'"
|
||||
)
|
||||
)
|
||||
return cursor.fetchone()[0]
|
@@ -1,5 +1,6 @@
|
||||
"""Connection manager."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type
|
||||
|
||||
@@ -58,6 +59,9 @@ class ConnectorManager(BaseComponent):
|
||||
from dbgpt_ext.datasource.rdbms.conn_oceanbase import ( # noqa: F401
|
||||
OceanBaseConnector,
|
||||
)
|
||||
|
||||
# 添加OracleConnector导入
|
||||
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector # noqa: F401
|
||||
from dbgpt_ext.datasource.rdbms.conn_postgresql import ( # noqa: F401
|
||||
PostgreSQLConnector,
|
||||
)
|
||||
@@ -175,6 +179,19 @@ class ConnectorManager(BaseComponent):
|
||||
if db_type.is_file_db():
|
||||
db_path = db_config.get("db_path")
|
||||
return connect_instance.from_file_path(db_path) # type: ignore
|
||||
elif db_type.value() == "oracle":
|
||||
logger.info("-------------Oracle Datasource------------")
|
||||
host = db_config.get("db_host")
|
||||
port = db_config.get("db_port")
|
||||
user = db_config.get("db_user")
|
||||
pwd = db_config.get("db_pwd")
|
||||
extConfig = db_config.get("ext_config")
|
||||
dbJson = json.loads(extConfig)
|
||||
service_name = dbJson.get("service_name", None)
|
||||
sid = (dbJson.get("sid", None),)
|
||||
return connect_instance.from_uri_db( # type: ignore
|
||||
host=host, port=port, user=user, pwd=pwd, service_name=service_name
|
||||
)
|
||||
else:
|
||||
db_host = db_config.get("db_host")
|
||||
db_port = db_config.get("db_port")
|
||||
|
101
tests/intetration_tests/datasource/test_conn_oracle.py
Normal file
101
tests/intetration_tests/datasource/test_conn_oracle.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
|
||||
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
|
||||
mysql -h 127.0.0.1 -uroot -p -P3307
|
||||
Enter password:
|
||||
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||
Your MySQL connection id is 2
|
||||
Server version: 5.7.41 MySQL Community Server (GPL)
|
||||
|
||||
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||
|
||||
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||
affiliates. Other names may be trademarks of their respective
|
||||
owners.
|
||||
|
||||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||
|
||||
> create database test;
|
||||
"""
|
||||
import pytest
|
||||
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
|
||||
|
||||
|
||||
_create_table_sql = """
|
||||
CREATE TABLE test (
|
||||
id NUMBER(11) NULL
|
||||
)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
# 注意:Oracle 默认端口是 1521,连接方式建议用 service_name
|
||||
conn = OracleConnector.from_uri_db(
|
||||
host="localhost",
|
||||
port=1521,
|
||||
user="oracle_user",
|
||||
pwd="********",
|
||||
service_name="ORCL", # 替换为你的 service_name 或 SID
|
||||
)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
try:
|
||||
conn.run("DROP TABLE test PURGE")
|
||||
except Exception:
|
||||
pass # 如果表不存在也忽略错误
|
||||
|
||||
def test_get_usable_table_names(db):
|
||||
db.run(_create_table_sql)
|
||||
db.run("COMMIT")
|
||||
table_names = db.get_usable_table_names()
|
||||
assert "TEST" in map(str.upper, table_names)
|
||||
|
||||
def test_get_table_info(db):
|
||||
db.run(_create_table_sql)
|
||||
db.run("COMMIT")
|
||||
table_info = db.get_table_info()
|
||||
assert "CREATE TABLE TEST" in table_info.upper()
|
||||
|
||||
def test_run_no_throw(db):
|
||||
result = db.run_no_throw("this is a error sql")
|
||||
# run_no_throw 返回的是 list,错误时为空
|
||||
assert result == [] or isinstance(result, list)
|
||||
|
||||
def test_get_index_empty(db):
|
||||
db.run(_create_table_sql)
|
||||
db.run("COMMIT")
|
||||
indexes = db.get_indexes("TEST")
|
||||
assert indexes == []
|
||||
|
||||
def test_get_fields(db):
|
||||
#db.run(_create_table_sql)
|
||||
#db.run("COMMIT")
|
||||
print("进入方法...")
|
||||
fields = db.get_fields("PY_TEST")
|
||||
print("正在打印字段信息...")
|
||||
for field in fields:
|
||||
print(f"Column Name: {field[0]}")
|
||||
print(f"Data Type: {field[1]}")
|
||||
print(f"Default Value: {field[2]}")
|
||||
print(f"Is Nullable: {field[3]}")
|
||||
print(f"Column Comment: {field[4]}")
|
||||
print("-" * 30) # 可选的分隔符
|
||||
#assert fields[0][0].upper() == "ID"
|
||||
|
||||
def test_get_charset(db):
|
||||
result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'")
|
||||
assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组
|
||||
|
||||
def test_get_users(db):
|
||||
users = db.get_users()
|
||||
assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)
|
||||
|
||||
def test_get_database_lists(db):
|
||||
cdb_result = db.run("SELECT CDB FROM V$DATABASE")
|
||||
if cdb_result[1][0] == "YES":
|
||||
databases = db.run("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'")
|
||||
pdb_names = [name[0] for name in databases[1:]]
|
||||
else:
|
||||
pdb_names = ["ORCL"]
|
||||
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)
|
Reference in New Issue
Block a user