community[minor]: Add SQLDatabaseLoader document loader (#18281)

- **Description:** A generic document loader adapter for SQLAlchemy, built on
top of LangChain's `SQLDatabase` utility and exposed as the new `SQLDatabaseLoader`.
  - **Needed by:** https://github.com/crate-workbench/langchain/pull/1
  - **Depends on:** GH-16655
  - **Addressed to:** @baskaryan, @cbornet, @eyurtsev

Hi from CrateDB again,

in the same spirit as GH-16243 and GH-16244, this patch breaks out
another commit from https://github.com/crate-workbench/langchain/pull/1,
in order to reduce the size of this patch before submitting it, and to
separate concerns.

To accompany the SQLAlchemy adapter implementation, the patch includes
integration tests for both SQLite and PostgreSQL. Let me know if
corresponding utility resources should be added at different spots.

With kind regards,
Andreas.


### Software Tests

```console
docker compose --file libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml up
```

```console
cd libs/community
pip install psycopg2-binary
pytest -vvv tests/integration_tests -k sqldatabase
```

```
14 passed
```



![image](https://github.com/langchain-ai/langchain/assets/453543/42be233c-eb37-4c76-a830-474276e01436)

---------

Co-authored-by: Andreas Motl <andreas.motl@crate.io>
This commit is contained in:
Eugene Yurtsev 2024-02-28 16:02:28 -05:00 committed by GitHub
parent a37dc83a9e
commit cd52433ba0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 539 additions and 0 deletions

View File

@ -32,3 +32,26 @@ services:
environment:
MONGO_INITDB_ROOT_USERNAME: langchain
MONGO_INITDB_ROOT_PASSWORD: langchain
postgres:
image: postgres:16
environment:
POSTGRES_DB: langchain
POSTGRES_USER: langchain
POSTGRES_PASSWORD: langchain
ports:
- "6023:5432"
command: |
postgres -c log_statement=all
healthcheck:
test:
[
"CMD-SHELL",
"psql postgresql://langchain:langchain@localhost/langchain --command 'SELECT 1;' || exit 1",
]
interval: 5s
retries: 60
volumes:
- postgres_data:/var/lib/postgresql/data
volumes:
postgres_data:

View File

@ -0,0 +1,40 @@
-- Provisioning table "mlb_teams_2012".
--
-- psql postgresql://postgres@localhost < mlb_teams_2012.sql
DROP TABLE IF EXISTS mlb_teams_2012;
CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);
INSERT INTO mlb_teams_2012
("Team", "Payroll (millions)", "Wins")
VALUES
('Nationals', 81.34, 98),
('Reds', 82.20, 97),
('Yankees', 197.96, 95),
('Giants', 117.62, 94),
('Braves', 83.31, 94),
('Athletics', 55.37, 94),
('Rangers', 120.51, 93),
('Orioles', 81.43, 93),
('Rays', 64.17, 90),
('Angels', 154.49, 89),
('Tigers', 132.30, 88),
('Cardinals', 110.30, 88),
('Dodgers', 95.14, 86),
('White Sox', 96.92, 85),
('Brewers', 97.65, 83),
('Phillies', 174.54, 81),
('Diamondbacks', 74.28, 81),
('Pirates', 63.43, 79),
('Padres', 55.24, 76),
('Mariners', 81.97, 75),
('Mets', 93.35, 74),
('Blue Jays', 75.48, 73),
('Royals', 60.91, 72),
('Marlins', 118.07, 69),
('Red Sox', 173.18, 69),
('Indians', 78.43, 68),
('Twins', 94.08, 66),
('Rockies', 78.06, 64),
('Cubs', 88.19, 61),
('Astros', 60.65, 55)
;

View File

@ -187,6 +187,7 @@ from langchain_community.document_loaders.sitemap import SitemapLoader
from langchain_community.document_loaders.slack_directory import SlackDirectoryLoader
from langchain_community.document_loaders.snowflake_loader import SnowflakeLoader
from langchain_community.document_loaders.spreedly import SpreedlyLoader
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
from langchain_community.document_loaders.srt import SRTLoader
from langchain_community.document_loaders.stripe import StripeLoader
from langchain_community.document_loaders.surrealdb import SurrealDBLoader
@ -376,6 +377,7 @@ __all__ = [
"SlackDirectoryLoader",
"SnowflakeLoader",
"SpreedlyLoader",
"SQLDatabaseLoader",
"StripeLoader",
"SurrealDBLoader",
"TelegramChatApiLoader",

View File

@ -0,0 +1,139 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
import sqlalchemy as sa
from langchain_community.docstore.document import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.sql_database import SQLDatabase
class SQLDatabaseLoader(BaseLoader):
    """
    Load documents by querying database tables supported by SQLAlchemy.

    For talking to the database, the document loader uses the `SQLDatabase`
    utility from the LangChain integration toolkit.

    Each document represents one row of the result.
    """

    def __init__(
        self,
        query: Union[str, sa.Select],
        db: SQLDatabase,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        page_content_mapper: Optional[Callable[..., str]] = None,
        metadata_mapper: Optional[Callable[..., Dict[str, Any]]] = None,
        source_columns: Optional[Sequence[str]] = None,
        include_rownum_into_metadata: bool = False,
        include_query_into_metadata: bool = False,
    ):
        """
        Args:
            query: The query to execute, either as SQL string or SQLAlchemy
                `Select` expression.
            db: A LangChain `SQLDatabase`, wrapping an SQLAlchemy engine.
            parameters: Optional. Parameters to pass to the query.
            page_content_mapper: Optional. Function to convert a row into a string
                to use as the `page_content` of the document. By default, the loader
                serializes the whole row into a string, including all columns.
            metadata_mapper: Optional. Function to convert a row into a dictionary
                to use as the `metadata` of the document. By default, no columns are
                selected into the metadata dictionary.
            source_columns: Optional. The names of the columns to use as the `source`
                within the metadata dictionary.
            include_rownum_into_metadata: Optional. Whether to include the row number
                into the metadata dictionary. Default: False.
            include_query_into_metadata: Optional. Whether to include the query
                expression into the metadata dictionary. Default: False.
        """
        self.query = query
        self.db: SQLDatabase = db
        self.parameters = parameters or {}
        self.page_content_mapper = (
            page_content_mapper or self.page_content_default_mapper
        )
        self.metadata_mapper = metadata_mapper or self.metadata_default_mapper
        self.source_columns = source_columns
        self.include_rownum_into_metadata = include_rownum_into_metadata
        self.include_query_into_metadata = include_query_into_metadata

    def lazy_load(self) -> Iterator[Document]:
        """
        Execute the query and lazily generate one `Document` per result row.

        Raises:
            TypeError: If `self.query` is neither a string nor an SQLAlchemy
                selectable.
        """
        # NOTE: No lazy `import sqlalchemy` guard is needed here — the module
        # imports `sqlalchemy as sa` unconditionally at the top, so this
        # module cannot be imported at all without SQLAlchemy installed.

        # Querying in `cursor` fetch mode will return an SQLAlchemy `Result` instance.
        result: sa.Result[Any]

        # Invoke the database query.
        if isinstance(self.query, sa.SelectBase):
            result = self.db._execute(  # type: ignore[assignment]
                self.query, fetch="cursor", parameters=self.parameters
            )
            # Render the selectable to SQL text for optional metadata inclusion.
            query_sql = str(self.query.compile(bind=self.db._engine))
        elif isinstance(self.query, str):
            result = self.db._execute(  # type: ignore[assignment]
                sa.text(self.query), fetch="cursor", parameters=self.parameters
            )
            query_sql = self.query
        else:
            raise TypeError(f"Unable to process query of unknown type: {self.query}")

        # Iterate database result rows and generate documents.
        for i, row in enumerate(result.mappings()):
            page_content = self.page_content_mapper(row)
            metadata = self.metadata_mapper(row)

            if self.include_rownum_into_metadata:
                metadata["row"] = i
            if self.include_query_into_metadata:
                metadata["query"] = query_sql

            if self.source_columns:
                source_values = [
                    value
                    for column, value in row.items()
                    if column in self.source_columns
                ]
                if source_values:
                    # Cast values to str: source columns may hold non-string
                    # values (numbers, dates), and `str.join` raises TypeError
                    # on non-string elements.
                    metadata["source"] = ",".join(map(str, source_values))

            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> List[Document]:
        """Materialize all documents from `lazy_load` into a list."""
        return list(self.lazy_load())

    @staticmethod
    def page_content_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> str:
        """
        A reasonable default function to convert a record into a "page content" string.

        Args:
            row: A result row mapping column names to values.
            column_names: Optional. Restrict serialization to these columns;
                by default, all columns are included.
        """
        if column_names is None:
            column_names = list(row.keys())
        return "\n".join(
            f"{column}: {value}"
            for column, value in row.items()
            if column in column_names
        )

    @staticmethod
    def metadata_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        A reasonable default function to convert a record into a "metadata" dictionary.

        Args:
            row: A result row mapping column names to values.
            column_names: Optional. Columns to copy into the metadata dictionary;
                by default, none are copied.
        """
        if column_names is None:
            return {}
        return {
            column: value for column, value in row.items() if column in column_names
        }

View File

@ -0,0 +1,10 @@
"""Module defines common test data."""
from pathlib import Path
_THIS_DIR = Path(__file__).parent
_EXAMPLES_DIR = _THIS_DIR / "examples"
# Paths to data files
MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv"
MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql"

View File

@ -0,0 +1,32 @@
"Team", "Payroll (millions)", "Wins"
"Nationals", 81.34, 98
"Reds", 82.20, 97
"Yankees", 197.96, 95
"Giants", 117.62, 94
"Braves", 83.31, 94
"Athletics", 55.37, 94
"Rangers", 120.51, 93
"Orioles", 81.43, 93
"Rays", 64.17, 90
"Angels", 154.49, 89
"Tigers", 132.30, 88
"Cardinals", 110.30, 88
"Dodgers", 95.14, 86
"White Sox", 96.92, 85
"Brewers", 97.65, 83
"Phillies", 174.54, 81
"Diamondbacks", 74.28, 81
"Pirates", 63.43, 79
"Padres", 55.24, 76
"Mariners", 81.97, 75
"Mets", 93.35, 74
"Blue Jays", 75.48, 73
"Royals", 60.91, 72
"Marlins", 118.07, 69
"Red Sox", 173.18, 69
"Indians", 78.43, 68
"Twins", 94.08, 66
"Rockies", 78.06, 64
"Cubs", 88.19, 61
"Astros", 60.65, 55
1 Team Payroll (millions) Wins
2 Nationals 81.34 98
3 Reds 82.20 97
4 Yankees 197.96 95
5 Giants 117.62 94
6 Braves 83.31 94
7 Athletics 55.37 94
8 Rangers 120.51 93
9 Orioles 81.43 93
10 Rays 64.17 90
11 Angels 154.49 89
12 Tigers 132.30 88
13 Cardinals 110.30 88
14 Dodgers 95.14 86
15 White Sox 96.92 85
16 Brewers 97.65 83
17 Phillies 174.54 81
18 Diamondbacks 74.28 81
19 Pirates 63.43 79
20 Padres 55.24 76
21 Mariners 81.97 75
22 Mets 93.35 74
23 Blue Jays 75.48 73
24 Royals 60.91 72
25 Marlins 118.07 69
26 Red Sox 173.18 69
27 Indians 78.43 68
28 Twins 94.08 66
29 Rockies 78.06 64
30 Cubs 88.19 61
31 Astros 60.65 55

View File

@ -0,0 +1,40 @@
-- Provisioning table "mlb_teams_2012".
--
-- psql postgresql://postgres@localhost < mlb_teams_2012.sql
DROP TABLE IF EXISTS mlb_teams_2012;
CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);
INSERT INTO mlb_teams_2012
("Team", "Payroll (millions)", "Wins")
VALUES
('Nationals', 81.34, 98),
('Reds', 82.20, 97),
('Yankees', 197.96, 95),
('Giants', 117.62, 94),
('Braves', 83.31, 94),
('Athletics', 55.37, 94),
('Rangers', 120.51, 93),
('Orioles', 81.43, 93),
('Rays', 64.17, 90),
('Angels', 154.49, 89),
('Tigers', 132.30, 88),
('Cardinals', 110.30, 88),
('Dodgers', 95.14, 86),
('White Sox', 96.92, 85),
('Brewers', 97.65, 83),
('Phillies', 174.54, 81),
('Diamondbacks', 74.28, 81),
('Pirates', 63.43, 79),
('Padres', 55.24, 76),
('Mariners', 81.97, 75),
('Mets', 93.35, 74),
('Blue Jays', 75.48, 73),
('Royals', 60.91, 72),
('Marlins', 118.07, 69),
('Red Sox', 173.18, 69),
('Indians', 78.43, 68),
('Twins', 94.08, 66),
('Rockies', 78.06, 64),
('Cubs', 88.19, 61),
('Astros', 60.65, 55)
;

View File

@ -0,0 +1,252 @@
"""
Test SQLAlchemy document loader functionality on behalf of SQLite and PostgreSQL.
To run the tests for SQLite, you need to have the `sqlite3` package installed.
To run the tests for PostgreSQL, you need to have the `psycopg2` package installed.
In addition, to launch the PostgreSQL instance, you can use the docker compose
file `tests/integration_tests/document_loaders/docker-compose/postgresql.yml`
within `libs/community`, e.g. `docker compose --file <path> up`. It has the
appropriate credentials set up, and exposes the database on port 6023.
"""
import functools
import logging
import typing
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
import sqlalchemy as sa
from langchain_community.utilities.sql_database import SQLDatabase
if typing.TYPE_CHECKING:
from _pytest.python import Metafunc
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
from tests.data import MLB_TEAMS_2012_SQL
logging.basicConfig(level=logging.DEBUG)
# Probe for the optional database drivers at import time. The resulting
# flags feed `pytest_generate_tests`, which only parameterizes test cases
# for backends whose driver is importable.
try:
    import sqlite3  # noqa: F401
    sqlite3_installed = True
except ImportError:
    warnings.warn("sqlite3 not installed, skipping corresponding tests", UserWarning)
    sqlite3_installed = False
# `psycopg2` is a third-party driver and needs a running PostgreSQL server;
# see the module docstring for setup instructions.
try:
    import psycopg2  # noqa: F401
    psycopg2_installed = True
except ImportError:
    warnings.warn("psycopg2 not installed, skipping corresponding tests", UserWarning)
    psycopg2_installed = False
@pytest.fixture()
def engine(db_uri: str) -> sa.Engine:
    """
    Create an SQLAlchemy engine for the parameterized database URI.
    """
    sqlalchemy_engine = sa.create_engine(db_uri, echo=False)
    return sqlalchemy_engine
@pytest.fixture()
def db(engine: sa.Engine) -> SQLDatabase:
    """Wrap the SQLAlchemy engine into a LangChain `SQLDatabase` utility."""
    wrapped = SQLDatabase(engine=engine)
    return wrapped
@pytest.fixture()
def provision_database(engine: sa.Engine) -> None:
    """
    Provision database with table schema and data.
    """
    ddl_and_data = MLB_TEAMS_2012_SQL.read_text()
    with engine.connect() as connection:
        # Start from a clean slate, then replay the provisioning script
        # statement by statement (naive split on ";" is fine for this file).
        connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
        for chunk in ddl_and_data.split(";"):
            statement = chunk.strip()
            if statement:
                connection.execute(sa.text(statement))
        connection.commit()
# Scratch directory hosting the throwaway SQLite database file; cleaned up
# automatically when the `TemporaryDirectory` object is finalized.
tmpdir = TemporaryDirectory()
def pytest_generate_tests(metafunc: "Metafunc") -> None:
    """
    Dynamically parameterize test cases to verify both SQLite and PostgreSQL.

    Backends whose driver failed to import at module load are omitted from
    the `db_uri` parameter list, effectively skipping their test cases.
    """
    if "db_uri" in metafunc.fixturenames:
        urls = []
        ids = []
        if sqlite3_installed:
            db_path = Path(tmpdir.name).joinpath("testdrive.sqlite")
            urls.append(f"sqlite:///{db_path}")
            ids.append("sqlite")
        if psycopg2_installed:
            # We use a non-standard port (6023) for testing purposes.
            # The easiest way to spin up the PostgreSQL instance is with the
            # docker compose file at `tests/integration_tests/document_loaders/
            # docker-compose/postgresql.yml` within `libs/community`; use
            # `docker compose --file <path> up` to start it. It has matching
            # credentials set up and is exposed on port 6023.
            urls.append(
                "postgresql+psycopg2://langchain:langchain@localhost:6023/langchain"
            )
            ids.append("postgresql")
        metafunc.parametrize("db_uri", urls, ids=ids)
def test_sqldatabase_loader_no_options(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader basics."""
    records = SQLDatabaseLoader("SELECT 1 AS a, 2 AS b", db=db).load()
    assert len(records) == 1
    document = records[0]
    assert document.page_content == "a: 1\nb: 2"
    assert document.metadata == {}
def test_sqldatabase_loader_include_rownum_into_metadata(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with row number in metadata."""
    records = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        include_rownum_into_metadata=True,
    ).load()
    assert len(records) == 1
    document = records[0]
    assert document.page_content == "a: 1\nb: 2"
    # Row numbering starts at zero.
    assert document.metadata == {"row": 0}
def test_sqldatabase_loader_include_query_into_metadata(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with query in metadata."""
    records = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b", db=db, include_query_into_metadata=True
    ).load()
    assert len(records) == 1
    document = records[0]
    assert document.page_content == "a: 1\nb: 2"
    # A textual query is propagated into the metadata verbatim.
    assert document.metadata == {"query": "SELECT 1 AS a, 2 AS b"}
def test_sqldatabase_loader_page_content_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined page content columns."""
    # Restrict the "page content" serialization to column "a" only.
    content_from_a = functools.partial(
        SQLDatabaseLoader.page_content_default_mapper, column_names=["a"]
    )
    records = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b",
        db=db,
        page_content_mapper=content_from_a,
    ).load()
    assert len(records) == 2
    first, second = records
    assert first.page_content == "a: 1"
    assert first.metadata == {}
    assert second.page_content == "a: 3"
    assert second.metadata == {}
def test_sqldatabase_loader_metadata_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined metadata columns."""
    # Select only column "b" into the metadata dictionary.
    metadata_from_b = functools.partial(
        SQLDatabaseLoader.metadata_default_mapper, column_names=["b"]
    )
    records = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        metadata_mapper=metadata_from_b,
    ).load()
    assert len(records) == 1
    assert records[0].metadata == {"b": 2}
def test_sqldatabase_loader_real_data_with_sql_no_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader with real data, querying by SQL statement."""
    records = SQLDatabaseLoader(
        query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";',
        db=db,
    ).load()
    assert len(records) == 30
    first = records[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert first.metadata == {}
def test_sqldatabase_loader_real_data_with_sql_and_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader, querying by SQL statement and parameters."""
    records = SQLDatabaseLoader(
        query='SELECT * FROM mlb_teams_2012 WHERE "Team" LIKE :search ORDER BY "Team";',
        parameters={"search": "R%"},
        db=db,
    ).load()
    # Six teams start with "R".
    assert len(records) == 6
    first = records[0]
    assert first.page_content == "Team: Rangers\nPayroll (millions): 120.51\nWins: 93"
    assert first.metadata == {}
def test_sqldatabase_loader_real_data_with_selectable(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader with real data, querying by SQLAlchemy selectable."""
    # Describe the provisioned table as an explicit SQLAlchemy table object.
    table = sa.Table(
        "mlb_teams_2012",
        sa.MetaData(),
        sa.Column("Team", sa.VARCHAR),
        sa.Column("Payroll (millions)", sa.FLOAT),
        sa.Column("Wins", sa.BIGINT),
    )
    # Query the database table using an SQLAlchemy selectable.
    statement = sa.select(table).order_by(table.c.Team)
    records = SQLDatabaseLoader(
        query=statement,
        db=db,
        include_query_into_metadata=True,
    ).load()
    assert len(records) == 30
    first = records[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    # The compiled SQL of the selectable ends up in the metadata.
    assert first.metadata == {
        "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
        'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
        'ORDER BY mlb_teams_2012."Team"'
    }

View File

@ -129,6 +129,7 @@ EXPECTED_ALL = [
"RocksetLoader",
"S3DirectoryLoader",
"S3FileLoader",
"SQLDatabaseLoader",
"SRTLoader",
"SeleniumURLLoader",
"SharePointLoader",