diff --git a/packages/dbgpt-ext/pyproject.toml b/packages/dbgpt-ext/pyproject.toml index 7890e5ca0..37a775b40 100644 --- a/packages/dbgpt-ext/pyproject.toml +++ b/packages/dbgpt-ext/pyproject.toml @@ -47,6 +47,10 @@ datasource_mysql = [ # libpq-dev and libmysqlclient-dev first. "mysqlclient==2.1.0", ] +datasource_oracle = [ + "oracledb==3.1.0", # use python-oracledb,new driver for Oracle +] + datasource_postgres = [ # "psycopg2", # In production, you can install psycopg2 instead of psycopg2-binary "psycopg2-binary", diff --git a/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py b/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py new file mode 100644 index 000000000..9a047a1fb --- /dev/null +++ b/packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py @@ -0,0 +1,204 @@ +"""Oracle connector using python-oracledb.""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type +from urllib.parse import quote_plus + +from sqlalchemy import text + +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + ResourceCategory, + auto_register_resource, +) +from dbgpt.datasource.rdbms.base import RDBMSConnector, RDBMSDatasourceParameters +from dbgpt.util.i18n_utils import _ + + +@auto_register_resource( + label=_("Oracle datasource"), + category=ResourceCategory.DATABASE, + tags={"order": TAGS_ORDER_HIGH}, + description=_( + "Enterprise-grade relational database with oracledb driver (python-oracledb)." + ), +) +@dataclass +class OracleParameters(RDBMSDatasourceParameters): + """Oracle connection parameters.""" + + __type__ = "oracle" + + driver: str = field( + default="oracle+oracledb", # ✅ 使用 python-oracledb 驱动 + metadata={ + "help": _("Driver name for Oracle, default is oracle+oracledb."), + }, + ) + + service_name: Optional[str] = field( + default=None, + metadata={ + "help": _("Oracle service name (alternative to SID)."), + }, + ) + + sid: Optional[str] = field( + default=None, + metadata={ + "help": _("Oracle SID (System ID, alternative to service name)."), + }, + ) + + def db_url(self, ssl: bool = False, charset: Optional[str] = None) -> str: + if self.service_name: + dsn = ( + f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})" + f"(PORT={self.port}))(CONNECT_DATA=(SERVICE_NAME={self.service_name})))" + ) + elif self.sid: + dsn = ( + f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={self.host})" + f"(PORT={self.port}))(CONNECT_DATA=(SID={self.sid})))" + ) + else: + raise ValueError("Either service_name or sid must be provided for Oracle.") + + return f"{self.driver}://{self.user}:{self.password}@{dsn}" + + def create_connector(self) -> "OracleConnector": + return OracleConnector.from_parameters(self) + + +class OracleConnector(RDBMSConnector): + db_type: str = "oracle" + db_dialect: str = "oracle" + driver: str = "oracle+oracledb" + + @classmethod + def param_class(cls) -> Type[RDBMSDatasourceParameters]: + return OracleParameters + + @classmethod + def from_uri_db( + cls, + host: str, + port: int, + user: str, + pwd: str, + sid: Optional[str] = None, + service_name: Optional[str] = None, + engine_args: Optional[dict] = None, + **kwargs, + ) -> "OracleConnector": + if not sid and not service_name: + raise ValueError("Must provide either sid or service_name") + + if service_name: + dsn = ( + f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)" + f"(HOST={host})(PORT={port}))(CONNECT_DATA=(SERVICE_NAME={service_name})))" + ) + else: + dsn = ( + f"(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={host})" + f"(PORT={port}))(CONNECT_DATA=(SID={sid})))" + ) + + bm_pwd = quote_plus(pwd) + db_url = f"{cls.driver}://{user}:{bm_pwd}@{dsn}" + + return cls.from_uri(db_url, engine_args=engine_args, **kwargs) + + def get_fields(self, table_name: str, db_name=None) -> List[Tuple]: + with self.session_scope() as session: + query = f""" + SELECT col.column_name, + col.data_type, + col.data_default, + col.nullable, + comm.comments + FROM user_tab_columns col + LEFT JOIN user_col_comments comm + ON col.table_name = comm.table_name + AND col.column_name = comm.column_name + WHERE col.table_name = '{table_name.upper()}' + """ + result = session.execute(text(query)) + return result.fetchall() + + def get_charset(self) -> str: + with self.session_scope() as session: + cursor = session.execute( + text( + "SELECT VALUE FROM NLS_DATABASE_PARAMETERS " + "WHERE PARAMETER = 'NLS_CHARACTERSET'" + ) + ) + return cursor.fetchone()[0] + + def get_grants(self): + with self.session_scope() as session: + cursor = session.execute(text("SELECT privilege FROM user_sys_privs")) + return cursor.fetchall() + + def get_users(self) -> List[Tuple[str, None]]: + with self.session_scope() as session: + cursor = session.execute(text("SELECT username FROM all_users")) + return [(row[0], None) for row in cursor.fetchall()] + + def get_database_names(self) -> List[str]: + with self.session_scope() as session: + is_cdb = session.execute(text("SELECT CDB FROM V$DATABASE")).fetchone()[0] + if is_cdb == "YES": + pdbs = session.execute( + text("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'") + ).fetchall() + return [name[0] for name in pdbs] + else: + return [ + session.execute( + text("SELECT sys_context('USERENV', 'CON_NAME') FROM dual") + ).fetchone()[0] + ] + + def get_table_comments(self, db_name: str) -> List[Tuple[str, str]]: + with self.session_scope() as session: + result = session.execute( + text("SELECT table_name, comments FROM user_tab_comments") + ) + return [(row[0], row[1]) for row in result.fetchall()] + + def get_table_comment(self, table_name: str) -> Dict: + with self.session_scope() as session: + cursor = session.execute( + text( + f"SELECT comments FROM user_tab_comments " + f"WHERE table_name = '{table_name.upper()}'" + ) + ) + row = cursor.fetchone() + return {"text": row[0] if row else ""} + + def get_column_comments( + self, db_name: str, table_name: str + ) -> List[Tuple[str, str]]: + with self.session_scope() as session: + cursor = session.execute( + text(f""" + SELECT column_name, comments + FROM user_col_comments + WHERE table_name = '{table_name.upper()}' + """) + ) + return [(row[0], row[1]) for row in cursor.fetchall()] + + def get_collation(self) -> str: + with self.session_scope() as session: + cursor = session.execute( + text( + "SELECT value FROM NLS_DATABASE_PARAMETERS " + "WHERE parameter = 'NLS_SORT'" + ) + ) + return cursor.fetchone()[0] diff --git a/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connector_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connector_manager.py index d853d1212..367941a4a 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connector_manager.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/datasource/manages/connector_manager.py @@ -1,5 +1,6 @@ """Connection manager.""" +import json import logging from typing import TYPE_CHECKING, Dict, List, Optional, Type @@ -58,6 +59,9 @@ class ConnectorManager(BaseComponent): from dbgpt_ext.datasource.rdbms.conn_oceanbase import ( # noqa: F401 OceanBaseConnector, ) + + # 添加OracleConnector导入 + from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector # noqa: F401 from dbgpt_ext.datasource.rdbms.conn_postgresql import ( # noqa: F401 PostgreSQLConnector, ) @@ -175,6 +179,19 @@ class ConnectorManager(BaseComponent): if db_type.is_file_db(): db_path = db_config.get("db_path") return connect_instance.from_file_path(db_path) # type: ignore + elif db_type.value() == "oracle": + logger.info("-------------Oracle Datasource------------") + host = db_config.get("db_host") + port = db_config.get("db_port") + user = db_config.get("db_user") + pwd = db_config.get("db_pwd") + extConfig = db_config.get("ext_config") + dbJson = json.loads(extConfig) + service_name = dbJson.get("service_name", None) + sid = (dbJson.get("sid", None),) + return connect_instance.from_uri_db( # type: ignore + host=host, port=port, user=user, pwd=pwd, service_name=service_name + ) else: db_host = db_config.get("db_host") db_port = db_config.get("db_port") diff --git a/tests/intetration_tests/datasource/test_conn_oracle.py b/tests/intetration_tests/datasource/test_conn_oracle.py new file mode 100644 index 000000000..5cc734e09 --- /dev/null +++ b/tests/intetration_tests/datasource/test_conn_oracle.py @@ -0,0 +1,101 @@ +""" + Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py + docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7 + mysql -h 127.0.0.1 -uroot -p -P3307 + Enter password: + Welcome to the MySQL monitor. Commands end with ; or \g. + Your MySQL connection id is 2 + Server version: 5.7.41 MySQL Community Server (GPL) + + Copyright (c) 2000, 2023, Oracle and/or its affiliates. + + Oracle is a registered trademark of Oracle Corporation and/or its + affiliates. Other names may be trademarks of their respective + owners. + + Type 'help;' or '\h' for help. Type '\c' to clear the current input statement. + + > create database test; +""" +import pytest +from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector + + +_create_table_sql = """ +CREATE TABLE test ( + id NUMBER(11) NULL +) +""" + +@pytest.fixture +def db(): + # 注意:Oracle 默认端口是 1521,连接方式建议用 service_name + conn = OracleConnector.from_uri_db( + host="localhost", + port=1521, + user="oracle_user", + pwd="********", + service_name="ORCL", # 替换为你的 service_name 或 SID + ) + try: + yield conn + finally: + try: + conn.run("DROP TABLE test PURGE") + except Exception: + pass # 如果表不存在也忽略错误 + +def test_get_usable_table_names(db): + db.run(_create_table_sql) + db.run("COMMIT") + table_names = db.get_usable_table_names() + assert "TEST" in map(str.upper, table_names) + +def test_get_table_info(db): + db.run(_create_table_sql) + db.run("COMMIT") + table_info = db.get_table_info() + assert "CREATE TABLE TEST" in table_info.upper() + +def test_run_no_throw(db): + result = db.run_no_throw("this is a error sql") + # run_no_throw 返回的是 list,错误时为空 + assert result == [] or isinstance(result, list) + +def test_get_index_empty(db): + db.run(_create_table_sql) + db.run("COMMIT") + indexes = db.get_indexes("TEST") + assert indexes == [] + +def test_get_fields(db): + #db.run(_create_table_sql) + #db.run("COMMIT") + print("进入方法...") + fields = db.get_fields("PY_TEST") + print("正在打印字段信息...") + for field in fields: + print(f"Column Name: {field[0]}") + print(f"Data Type: {field[1]}") + print(f"Default Value: {field[2]}") + print(f"Is Nullable: {field[3]}") + print(f"Column Comment: {field[4]}") + print("-" * 30) # 可选的分隔符 + #assert fields[0][0].upper() == "ID" + +def test_get_charset(db): + result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'") + assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组 + +def test_get_users(db): + users = db.get_users() + assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users) + +def test_get_database_lists(db): + cdb_result = db.run("SELECT CDB FROM V$DATABASE") + if cdb_result[1][0] == "YES": + databases = db.run("SELECT NAME FROM V$PDBS WHERE OPEN_MODE = 'READ WRITE'") + pdb_names = [name[0] for name in databases[1:]] + else: + pdb_names = ["ORCL"] + assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names) \ No newline at end of file