diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 33c873e60e0..c3600a8caf9 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -32,3 +32,26 @@ services: environment: MONGO_INITDB_ROOT_USERNAME: langchain MONGO_INITDB_ROOT_PASSWORD: langchain + postgres: + image: postgres:16 + environment: + POSTGRES_DB: langchain + POSTGRES_USER: langchain + POSTGRES_PASSWORD: langchain + ports: + - "6023:5432" + command: | + postgres -c log_statement=all + healthcheck: + test: + [ + "CMD-SHELL", + "psql postgresql://langchain:langchain@localhost/langchain --command 'SELECT 1;' || exit 1", + ] + interval: 5s + retries: 60 + volumes: + - postgres_data:/var/lib/postgresql/data + +volumes: + postgres_data: diff --git a/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql new file mode 100644 index 00000000000..33cb765a38e --- /dev/null +++ b/docs/docs/integrations/document_loaders/example_data/mlb_teams_2012.sql @@ -0,0 +1,40 @@ +-- Provisioning table "mlb_teams_2012". 
+-- +-- psql postgresql://postgres@localhost < mlb_teams_2012.sql + +DROP TABLE IF EXISTS mlb_teams_2012; +CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT); +INSERT INTO mlb_teams_2012 + ("Team", "Payroll (millions)", "Wins") +VALUES + ('Nationals', 81.34, 98), + ('Reds', 82.20, 97), + ('Yankees', 197.96, 95), + ('Giants', 117.62, 94), + ('Braves', 83.31, 94), + ('Athletics', 55.37, 94), + ('Rangers', 120.51, 93), + ('Orioles', 81.43, 93), + ('Rays', 64.17, 90), + ('Angels', 154.49, 89), + ('Tigers', 132.30, 88), + ('Cardinals', 110.30, 88), + ('Dodgers', 95.14, 86), + ('White Sox', 96.92, 85), + ('Brewers', 97.65, 83), + ('Phillies', 174.54, 81), + ('Diamondbacks', 74.28, 81), + ('Pirates', 63.43, 79), + ('Padres', 55.24, 76), + ('Mariners', 81.97, 75), + ('Mets', 93.35, 74), + ('Blue Jays', 75.48, 73), + ('Royals', 60.91, 72), + ('Marlins', 118.07, 69), + ('Red Sox', 173.18, 69), + ('Indians', 78.43, 68), + ('Twins', 94.08, 66), + ('Rockies', 78.06, 64), + ('Cubs', 88.19, 61), + ('Astros', 60.65, 55) +; diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py index 10190996e38..4d708a89ad8 100644 --- a/libs/community/langchain_community/document_loaders/__init__.py +++ b/libs/community/langchain_community/document_loaders/__init__.py @@ -187,6 +187,7 @@ from langchain_community.document_loaders.sitemap import SitemapLoader from langchain_community.document_loaders.slack_directory import SlackDirectoryLoader from langchain_community.document_loaders.snowflake_loader import SnowflakeLoader from langchain_community.document_loaders.spreedly import SpreedlyLoader +from langchain_community.document_loaders.sql_database import SQLDatabaseLoader from langchain_community.document_loaders.srt import SRTLoader from langchain_community.document_loaders.stripe import StripeLoader from langchain_community.document_loaders.surrealdb import 
class SQLDatabaseLoader(BaseLoader):
    """
    Load documents by querying database tables supported by SQLAlchemy.

    For talking to the database, the document loader uses the `SQLDatabase`
    utility from the LangChain integration toolkit.

    Each document represents one row of the result.
    """

    def __init__(
        self,
        query: Union[str, sa.Select],
        db: SQLDatabase,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        page_content_mapper: Optional[Callable[..., str]] = None,
        metadata_mapper: Optional[Callable[..., Dict[str, Any]]] = None,
        source_columns: Optional[Sequence[str]] = None,
        include_rownum_into_metadata: bool = False,
        include_query_into_metadata: bool = False,
    ):
        """
        Args:
            query: The query to execute, either as an SQL string or as an
              SQLAlchemy selectable.
            db: A LangChain `SQLDatabase`, wrapping an SQLAlchemy engine.
            parameters: Optional. Parameters to pass to the query.
            page_content_mapper: Optional. Function to convert a row into a string
              to use as the `page_content` of the document. By default, the loader
              serializes the whole row into a string, including all columns.
            metadata_mapper: Optional. Function to convert a row into a dictionary
              to use as the `metadata` of the document. By default, no columns are
              selected into the metadata dictionary.
            source_columns: Optional. The names of the columns to use as the `source`
              within the metadata dictionary. The values of all matching columns
              are converted to strings and joined with a comma.
            include_rownum_into_metadata: Optional. Whether to include the row number
              into the metadata dictionary. Default: False.
            include_query_into_metadata: Optional. Whether to include the query
              expression into the metadata dictionary. Default: False.
        """
        self.query = query
        self.db: SQLDatabase = db
        self.parameters = parameters or {}
        self.page_content_mapper = (
            page_content_mapper or self.page_content_default_mapper
        )
        self.metadata_mapper = metadata_mapper or self.metadata_default_mapper
        self.source_columns = source_columns
        self.include_rownum_into_metadata = include_rownum_into_metadata
        self.include_query_into_metadata = include_query_into_metadata

    def lazy_load(self) -> Iterator[Document]:
        """
        Execute the query and lazily yield one `Document` per result row.

        Yields:
            A `Document` whose `page_content` and `metadata` are produced by the
            configured mapper functions, optionally enriched with the `row`,
            `query` and `source` metadata keys.

        Raises:
            TypeError: If `query` is neither a string nor an SQLAlchemy selectable.
              As this method is a generator, the error surfaces on first iteration.
        """
        # Querying in `cursor` fetch mode will return an SQLAlchemy `Result` instance.
        result: sa.Result[Any]

        # Invoke the database query.
        if isinstance(self.query, sa.SelectBase):
            result = self.db._execute(  # type: ignore[assignment]
                self.query, fetch="cursor", parameters=self.parameters
            )
            query_sql = str(self.query.compile(bind=self.db._engine))
        elif isinstance(self.query, str):
            result = self.db._execute(  # type: ignore[assignment]
                sa.text(self.query), fetch="cursor", parameters=self.parameters
            )
            query_sql = self.query
        else:
            raise TypeError(f"Unable to process query of unknown type: {self.query}")

        # Loop-invariant: resolve the optional setting once, not once per column.
        source_columns = self.source_columns or ()

        # Iterate database result rows and generate one document per row.
        for i, row in enumerate(result.mappings()):
            page_content = self.page_content_mapper(row)
            # Copy the mapper's result, so that a mapper handing out a shared
            # dictionary does not leak per-row keys from one document to the next.
            metadata = dict(self.metadata_mapper(row))

            if self.include_rownum_into_metadata:
                metadata["row"] = i
            if self.include_query_into_metadata:
                metadata["query"] = query_sql

            # Column values may be of any type (numbers, dates, ...), while
            # `str.join` accepts strings only, hence the explicit conversion.
            source_values = [
                str(value) for column, value in row.items() if column in source_columns
            ]
            if source_values:
                metadata["source"] = ",".join(source_values)

            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> List[Document]:
        """Execute the query and return all resulting documents as a list."""
        return list(self.lazy_load())

    @staticmethod
    def page_content_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> str:
        """
        A reasonable default function to convert a record into a "page content" string.

        Args:
            row: The result row, as a mapping of column name to value.
            column_names: Optional. The columns to serialize. By default, all
              columns of the row are included.

        Returns:
            One `column: value` line per selected column, in row order.
        """
        if column_names is None:
            column_names = list(row.keys())
        return "\n".join(
            f"{column}: {value}"
            for column, value in row.items()
            if column in column_names
        )

    @staticmethod
    def metadata_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        A reasonable default function to convert a record into a "metadata" dictionary.

        Args:
            row: The result row, as a mapping of column name to value.
            column_names: Optional. The columns to copy into the metadata. By
              default, no columns are selected.

        Returns:
            A new dictionary holding the selected columns and their values.
        """
        if column_names is None:
            return {}

        return {
            column: value for column, value in row.items() if column in column_names
        }
"""
Test SQLAlchemy document loader functionality on behalf of SQLite and PostgreSQL.

To run the tests for SQLite, you need to have the `sqlite3` package installed.

To run the tests for PostgreSQL, you need to have the `psycopg2` package installed.
In addition, to launch the PostgreSQL instance, you can use the docker compose file
located at the root of the repo, `langchain/docker/docker-compose.yml`. Use the
command `docker compose up postgres` to start the instance. It will have the
appropriate credentials set up including being exposed on the appropriate port.
"""
import functools
import logging
import typing
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
import sqlalchemy as sa

from langchain_community.utilities.sql_database import SQLDatabase

if typing.TYPE_CHECKING:
    from _pytest.python import Metafunc

from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
from tests.data import MLB_TEAMS_2012_SQL

logging.basicConfig(level=logging.DEBUG)


try:
    import sqlite3  # noqa: F401

    sqlite3_installed = True
except ImportError:
    warnings.warn("sqlite3 not installed, skipping corresponding tests", UserWarning)
    sqlite3_installed = False

try:
    import psycopg2  # noqa: F401

    psycopg2_installed = True
except ImportError:
    warnings.warn("psycopg2 not installed, skipping corresponding tests", UserWarning)
    psycopg2_installed = False


@pytest.fixture()
def engine(db_uri: str) -> typing.Iterator[sa.Engine]:
    """
    Provide an SQLAlchemy engine object for the parameterized database URI.

    The engine is disposed after the test, so that pooled connections are not
    left open across the test session.
    """
    engine = sa.create_engine(db_uri, echo=False)
    yield engine
    engine.dispose()


@pytest.fixture()
def db(engine: sa.Engine) -> SQLDatabase:
    """Wrap the SQLAlchemy engine into a LangChain `SQLDatabase`."""
    return SQLDatabase(engine=engine)


@pytest.fixture()
def provision_database(engine: sa.Engine) -> None:
    """
    Provision database with table schema and data.
    """
    sql_statements = MLB_TEAMS_2012_SQL.read_text()
    with engine.connect() as connection:
        connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
        # The provisioning file holds multiple statements, which are executed
        # one by one. Empty fragments, like the one after the trailing
        # semicolon, are skipped.
        for statement in sql_statements.split(";"):
            statement = statement.strip()
            if not statement:
                continue
            connection.execute(sa.text(statement))
        connection.commit()


# Directory hosting the SQLite database file, shared by all test cases.
tmpdir = TemporaryDirectory()


def pytest_generate_tests(metafunc: "Metafunc") -> None:
    """
    Dynamically parameterize test cases to verify both SQLite and PostgreSQL.
    """
    if "db_uri" in metafunc.fixturenames:
        urls = []
        ids = []
        if sqlite3_installed:
            db_path = Path(tmpdir.name).joinpath("testdrive.sqlite")
            urls.append(f"sqlite:///{db_path}")
            ids.append("sqlite")
        if psycopg2_installed:
            # We use non-standard port for testing purposes.
            # The easiest way to spin up the PostgreSQL instance is to use
            # the docker compose file at the root of the repo located at
            # langchain/docker/docker-compose.yml
            # use `docker compose up postgres` to start the instance
            # it will have the appropriate credentials set up including
            # being exposed on the appropriate port.
            urls.append(
                "postgresql+psycopg2://langchain:langchain@localhost:6023/langchain"
            )
            ids.append("postgresql")

        metafunc.parametrize("db_uri", urls, ids=ids)


def test_sqldatabase_loader_no_options(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader basics."""

    loader = SQLDatabaseLoader("SELECT 1 AS a, 2 AS b", db=db)
    docs = loader.load()

    assert len(docs) == 1
    assert docs[0].page_content == "a: 1\nb: 2"
    assert docs[0].metadata == {}


def test_sqldatabase_loader_include_rownum_into_metadata(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with row number in metadata."""

    loader = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        include_rownum_into_metadata=True,
    )
    docs = loader.load()

    assert len(docs) == 1
    assert docs[0].page_content == "a: 1\nb: 2"
    assert docs[0].metadata == {"row": 0}


def test_sqldatabase_loader_include_query_into_metadata(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with query in metadata."""

    loader = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b", db=db, include_query_into_metadata=True
    )
    docs = loader.load()

    assert len(docs) == 1
    assert docs[0].page_content == "a: 1\nb: 2"
    assert docs[0].metadata == {"query": "SELECT 1 AS a, 2 AS b"}


def test_sqldatabase_loader_page_content_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined page content columns."""

    # Define a custom callback function to convert a row into a "page content" string.
    row_to_content = functools.partial(
        SQLDatabaseLoader.page_content_default_mapper, column_names=["a"]
    )

    # The assertions below depend on the row order. SQL does not guarantee any
    # order for the result of a `UNION`, so the query requests it explicitly.
    loader = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b ORDER BY a",
        db=db,
        page_content_mapper=row_to_content,
    )
    docs = loader.load()

    assert len(docs) == 2
    assert docs[0].page_content == "a: 1"
    assert docs[0].metadata == {}

    assert docs[1].page_content == "a: 3"
    assert docs[1].metadata == {}


def test_sqldatabase_loader_metadata_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined metadata columns."""

    # Define a custom callback function to convert a row into a "metadata" dictionary.
    row_to_metadata = functools.partial(
        SQLDatabaseLoader.metadata_default_mapper, column_names=["b"]
    )

    loader = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        metadata_mapper=row_to_metadata,
    )
    docs = loader.load()

    assert len(docs) == 1
    assert docs[0].metadata == {"b": 2}


def test_sqldatabase_loader_real_data_with_sql_no_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader with real data, querying by SQL statement."""

    loader = SQLDatabaseLoader(
        query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";',
        db=db,
    )
    docs = loader.load()

    assert len(docs) == 30
    assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert docs[0].metadata == {}


def test_sqldatabase_loader_real_data_with_sql_and_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader, querying by SQL statement and parameters."""

    loader = SQLDatabaseLoader(
        query='SELECT * FROM mlb_teams_2012 WHERE "Team" LIKE :search ORDER BY "Team";',
        parameters={"search": "R%"},
        db=db,
    )
    docs = loader.load()

    assert len(docs) == 6
    assert docs[0].page_content == "Team: Rangers\nPayroll (millions): 120.51\nWins: 93"
    assert docs[0].metadata == {}


def test_sqldatabase_loader_real_data_with_selectable(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader with real data, querying by SQLAlchemy selectable."""

    # Define an SQLAlchemy table.
    mlb_teams_2012 = sa.Table(
        "mlb_teams_2012",
        sa.MetaData(),
        sa.Column("Team", sa.VARCHAR),
        sa.Column("Payroll (millions)", sa.FLOAT),
        sa.Column("Wins", sa.BIGINT),
    )

    # Query the database table using an SQLAlchemy selectable.
    select = sa.select(mlb_teams_2012).order_by(mlb_teams_2012.c.Team)
    loader = SQLDatabaseLoader(
        query=select,
        db=db,
        include_query_into_metadata=True,
    )
    docs = loader.load()

    assert len(docs) == 30
    assert docs[0].page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert docs[0].metadata == {
        "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
        'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
        'ORDER BY mlb_teams_2012."Team"'
    }