diff --git a/docs/docs/integrations/document_loaders/oracleadb_loader.ipynb b/docs/docs/integrations/document_loaders/oracleadb_loader.ipynb new file mode 100644 index 00000000000..63b23c1e920 --- /dev/null +++ b/docs/docs/integrations/document_loaders/oracleadb_loader.ipynb @@ -0,0 +1,154 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Oracle Autonomous Database\n", + "\n", + "This notebook covers how to load documents from oracle autonomous database, the loader supports connection with connection string or tns config.\n", + "\n", + "## Prerequisites\n", + "1. Database runs in a 'Thin' mode.\n", + " https://python-oracledb.readthedocs.io/en/latest/user_guide/appendix_b.html\n", + "2. `pip install oracledb`\n", + " https://python-oracledb.readthedocs.io/en/latest/user_guide/installation.html" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "pip install oracledb" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "is_executing": true + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from langchain_community.document_loaders import OracleAutonomousDatabaseLoader\n", + "from settings import s" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "is_executing": true + } + } + }, + { + "cell_type": "markdown", + "source": [ + "With mutual TLS authentication (mTLS), wallet_location and wallet_password are required to create the connection, user can create connection by providing either connection string or tns config details." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "SQL_QUERY = \"select prod_id, time_id from sh.costs fetch first 5 rows only\"\n", + "\n", + "doc_loader_1 = OracleAutonomousDatabaseLoader(\n", + " user=s.USERNAME,\n", + " password=s.PASSWORD,\n", + " schema=s.SCHEMA,\n", + " config_dir=s.CONFIG_DIR,\n", + " wallet_location=s.WALLET_LOCATION,\n", + " wallet_password=s.PASSWORD,\n", + " tns_name=s.TNS_NAME,\n", + " query=SQL_QUERY,\n", + ")\n", + "doc_1 = doc_loader_1.load()\n", + "\n", + "doc_loader_2 = OracleAutonomousDatabaseLoader(\n", + " user=s.USERNAME,\n", + " password=s.PASSWORD,\n", + " schema=s.SCHEMA,\n", + " connection_string=s.CONNECTION_STRING,\n", + " wallet_location=s.WALLET_LOCATION,\n", + " wallet_password=s.PASSWORD,\n", + " query=SQL_QUERY,\n", + ")\n", + "doc_2 = doc_loader_2.load()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "is_executing": true + } + } + }, + { + "cell_type": "markdown", + "source": [ + "With TLS authentication, wallet_location and wallet_password are not required." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "doc_loader_3 = OracleAutonomousDatabaseLoader(\n", + " user=s.USERNAME,\n", + " password=s.PASSWORD,\n", + " schema=s.SCHEMA,\n", + " config_dir=s.CONFIG_DIR,\n", + " tns_name=s.TNS_NAME,\n", + " query=SQL_QUERY,\n", + ")\n", + "doc_3 = doc_loader_3.load()\n", + "\n", + "doc_loader_4 = OracleAutonomousDatabaseLoader(\n", + " user=s.USERNAME,\n", + " password=s.PASSWORD,\n", + " schema=s.SCHEMA,\n", + " connection_string=s.CONNECTION_STRING,\n", + " query=SQL_QUERY,\n", + ")\n", + "doc_4 = doc_loader_4.load()" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py index 973981c1111..927ae41819a 100644 --- a/libs/community/langchain_community/document_loaders/__init__.py +++ b/libs/community/langchain_community/document_loaders/__init__.py @@ -121,6 +121,7 @@ _module_lookup = { "OneDriveLoader": "langchain_community.document_loaders.onedrive", "OnlinePDFLoader": "langchain_community.document_loaders.pdf", "OpenCityDataLoader": "langchain_community.document_loaders.open_city_data", + "OracleAutonomousDatabaseLoader": "langchain_community.document_loaders.oracleadb_loader", # noqa: E501 "OutlookMessageLoader": "langchain_community.document_loaders.email", "PDFMinerLoader": "langchain_community.document_loaders.pdf", "PDFMinerPDFasHTMLLoader": "langchain_community.document_loaders.pdf", diff --git a/libs/community/langchain_community/document_loaders/oracleadb_loader.py b/libs/community/langchain_community/document_loaders/oracleadb_loader.py new file mode 100644 index 00000000000..90a8d608855 --- /dev/null +++ b/libs/community/langchain_community/document_loaders/oracleadb_loader.py @@ -0,0 +1,124 @@ +from typing import Any, Dict, List, Optional + +from langchain_core.documents import Document + +from langchain_community.document_loaders.base import BaseLoader + + +class OracleAutonomousDatabaseLoader(BaseLoader): + """ + Load from oracle adb + + Autonomous Database connection can be made by either connection_string + or tns name. wallet_location and wallet_password are required + for TLS connection. + Each document will represent one row of the query result. + Columns are written into the `page_content` and 'metadata' in + constructor is written into 'metadata' of document, + by default, the 'metadata' is None. + """ + + def __init__( + self, + query: str, + user: str, + password: str, + *, + schema: Optional[str] = None, + tns_name: Optional[str] = None, + config_dir: Optional[str] = None, + wallet_location: Optional[str] = None, + wallet_password: Optional[str] = None, + connection_string: Optional[str] = None, + metadata: Optional[List[str]] = None, + ): + """ + init method + :param query: sql query to execute + :param user: username + :param password: user password + :param schema: schema to run in database + :param tns_name: tns name in tnsname.ora + :param config_dir: directory of config files(tnsname.ora, wallet) + :param wallet_location: location of wallet + :param wallet_password: password of wallet + :param connection_string: connection string to connect to adb instance + :param metadata: metadata used in document + """ + # Mandatory required arguments. + self.query = query + self.user = user + self.password = password + + # Schema + self.schema = schema + + # TNS connection Method + self.tns_name = tns_name + self.config_dir = config_dir + + # Wallet configuration is required for mTLS connection + self.wallet_location = wallet_location + self.wallet_password = wallet_password + + # Connection String connection method + self.connection_string = connection_string + + # metadata column + self.metadata = metadata + + # dsn + self.dsn: Optional[str] + self._set_dsn() + + def _set_dsn(self) -> None: + if self.connection_string: + self.dsn = self.connection_string + elif self.tns_name: + self.dsn = self.tns_name + + def _run_query(self) -> List[Dict[str, Any]]: + try: + import oracledb + except ImportError as e: + raise ImportError( + "Could not import oracledb, " + "please install with 'pip install oracledb'" + ) from e + connect_param = {"user": self.user, "password": self.password, "dsn": self.dsn} + if self.dsn == self.tns_name: + connect_param["config_dir"] = self.config_dir + if self.wallet_location and self.wallet_password: + connect_param["wallet_location"] = self.wallet_location + connect_param["wallet_password"] = self.wallet_password + + try: + connection = oracledb.connect(**connect_param) + cursor = connection.cursor() + if self.schema: + cursor.execute(f"alter session set current_schema={self.schema}") + cursor.execute(self.query) + columns = [col[0] for col in cursor.description] + data = cursor.fetchall() + data = [dict(zip(columns, row)) for row in data] + except oracledb.DatabaseError as e: + print("Got error while connecting: " + str(e)) # noqa: T201 + data = [] + finally: + cursor.close() + connection.close() + + return data + + def load(self) -> List[Document]: + data = self._run_query() + documents = [] + metadata_columns = self.metadata if self.metadata else [] + for row in data: + metadata = { + key: value for key, value in row.items() if key in metadata_columns + } + doc = Document(page_content=str(row), metadata=metadata) + documents.append(doc) + + return documents diff --git a/libs/community/tests/unit_tests/document_loaders/test_imports.py b/libs/community/tests/unit_tests/document_loaders/test_imports.py index ad4da720162..fdfff2bfacf 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/community/tests/unit_tests/document_loaders/test_imports.py @@ -107,6 +107,7 @@ EXPECTED_ALL = [ "OneDriveLoader", "OnlinePDFLoader", "OpenCityDataLoader", + "OracleAutonomousDatabaseLoader", "OutlookMessageLoader", "PDFMinerLoader", "PDFMinerPDFasHTMLLoader", diff --git a/libs/community/tests/unit_tests/document_loaders/test_oracleadb.py b/libs/community/tests/unit_tests/document_loaders/test_oracleadb.py new file mode 100644 index 00000000000..4ec5597bb9f --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/test_oracleadb.py @@ -0,0 +1,56 @@ +from typing import Dict, List +from unittest.mock import MagicMock, patch + +from langchain_core.documents import Document + +from langchain_community.document_loaders.oracleadb_loader import ( + OracleAutonomousDatabaseLoader, +) + + +def raw_docs() -> List[Dict]: + return [ + {"FIELD1": "1", "FIELD_JSON": {"INNER_FIELD1": "1", "INNER_FIELD2": "1"}}, + {"FIELD1": "2", "FIELD_JSON": {"INNER_FIELD1": "2", "INNER_FIELD2": "2"}}, + {"FIELD1": "3", "FIELD_JSON": {"INNER_FIELD1": "3", "INNER_FIELD2": "3"}}, + ] + + +def expected_documents() -> List[Document]: + return [ + Document( + page_content="{'FIELD1': '1', 'FIELD_JSON': " + "{'INNER_FIELD1': '1', 'INNER_FIELD2': '1'}}", + metadata={"FIELD1": "1"}, + ), + Document( + page_content="{'FIELD1': '2', 'FIELD_JSON': " + "{'INNER_FIELD1': '2', 'INNER_FIELD2': '2'}}", + metadata={"FIELD1": "2"}, + ), + Document( + page_content="{'FIELD1': '3', 'FIELD_JSON': " + "{'INNER_FIELD1': '3', 'INNER_FIELD2': '3'}}", + metadata={"FIELD1": "3"}, + ), + ] + + +@patch( + "langchain_community.document_loaders.oracleadb_loader.OracleAutonomousDatabaseLoader._run_query" +) +def test_oracle_loader_load(mock_query: MagicMock) -> None: + """Test oracleDB loader load function.""" + + mock_query.return_value = raw_docs() + loader = OracleAutonomousDatabaseLoader( + query="Test query", + user="Test user", + password="Test password", + connection_string="Test connection string", + metadata=["FIELD1"], + ) + + documents = loader.load() + + assert documents == expected_documents()