mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-30 16:24:24 +00:00
community[minor]: Add SQLDatabaseLoader
document loader (#18281)
- **Description:** A generic document loader adapter for SQLAlchemy on top of LangChain's `SQLDatabaseLoader`. - **Needed by:** https://github.com/crate-workbench/langchain/pull/1 - **Depends on:** GH-16655 - **Addressed to:** @baskaryan, @cbornet, @eyurtsev Hi from CrateDB again, in the same spirit like GH-16243 and GH-16244, this patch breaks out another commit from https://github.com/crate-workbench/langchain/pull/1, in order to reduce the size of this patch before submitting it, and to separate concerns. To accompany the SQLAlchemy adapter implementation, the patch includes integration tests for both SQLite and PostgreSQL. Let me know if corresponding utility resources should be added at different spots. With kind regards, Andreas. ### Software Tests ```console docker compose --file libs/community/tests/integration_tests/document_loaders/docker-compose/postgresql.yml up ``` ```console cd libs/community pip install psycopg2-binary pytest -vvv tests/integration_tests -k sqldatabase ``` ``` 14 passed ```  --------- Co-authored-by: Andreas Motl <andreas.motl@crate.io>
This commit is contained in:
parent
a37dc83a9e
commit
cd52433ba0
@ -32,3 +32,26 @@ services:
|
||||
environment:
|
||||
MONGO_INITDB_ROOT_USERNAME: langchain
|
||||
MONGO_INITDB_ROOT_PASSWORD: langchain
|
||||
postgres:
|
||||
image: postgres:16
|
||||
environment:
|
||||
POSTGRES_DB: langchain
|
||||
POSTGRES_USER: langchain
|
||||
POSTGRES_PASSWORD: langchain
|
||||
ports:
|
||||
- "6023:5432"
|
||||
command: |
|
||||
postgres -c log_statement=all
|
||||
healthcheck:
|
||||
test:
|
||||
[
|
||||
"CMD-SHELL",
|
||||
"psql postgresql://langchain:langchain@localhost/langchain --command 'SELECT 1;' || exit 1",
|
||||
]
|
||||
interval: 5s
|
||||
retries: 60
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
@ -0,0 +1,40 @@
|
||||
-- Provisioning table "mlb_teams_2012".
--
-- Usage, e.g. against a local PostgreSQL instance:
-- psql postgresql://postgres@localhost < mlb_teams_2012.sql

DROP TABLE IF EXISTS mlb_teams_2012;
CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);

-- 2012 MLB season data: team name, payroll in millions of USD, season wins.
INSERT INTO mlb_teams_2012
("Team", "Payroll (millions)", "Wins")
VALUES
('Nationals', 81.34, 98),
('Reds', 82.20, 97),
('Yankees', 197.96, 95),
('Giants', 117.62, 94),
('Braves', 83.31, 94),
('Athletics', 55.37, 94),
('Rangers', 120.51, 93),
('Orioles', 81.43, 93),
('Rays', 64.17, 90),
('Angels', 154.49, 89),
('Tigers', 132.30, 88),
('Cardinals', 110.30, 88),
('Dodgers', 95.14, 86),
('White Sox', 96.92, 85),
('Brewers', 97.65, 83),
('Phillies', 174.54, 81),
('Diamondbacks', 74.28, 81),
('Pirates', 63.43, 79),
('Padres', 55.24, 76),
('Mariners', 81.97, 75),
('Mets', 93.35, 74),
('Blue Jays', 75.48, 73),
('Royals', 60.91, 72),
('Marlins', 118.07, 69),
('Red Sox', 173.18, 69),
('Indians', 78.43, 68),
('Twins', 94.08, 66),
('Rockies', 78.06, 64),
('Cubs', 88.19, 61),
('Astros', 60.65, 55)
;
|
@ -187,6 +187,7 @@ from langchain_community.document_loaders.sitemap import SitemapLoader
|
||||
from langchain_community.document_loaders.slack_directory import SlackDirectoryLoader
|
||||
from langchain_community.document_loaders.snowflake_loader import SnowflakeLoader
|
||||
from langchain_community.document_loaders.spreedly import SpreedlyLoader
|
||||
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
|
||||
from langchain_community.document_loaders.srt import SRTLoader
|
||||
from langchain_community.document_loaders.stripe import StripeLoader
|
||||
from langchain_community.document_loaders.surrealdb import SurrealDBLoader
|
||||
@ -376,6 +377,7 @@ __all__ = [
|
||||
"SlackDirectoryLoader",
|
||||
"SnowflakeLoader",
|
||||
"SpreedlyLoader",
|
||||
"SQLDatabaseLoader",
|
||||
"StripeLoader",
|
||||
"SurrealDBLoader",
|
||||
"TelegramChatApiLoader",
|
||||
|
@ -0,0 +1,139 @@
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from langchain_community.docstore.document import Document
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
class SQLDatabaseLoader(BaseLoader):
    """
    Load documents by querying database tables supported by SQLAlchemy.

    For talking to the database, the document loader uses the `SQLDatabase`
    utility from the LangChain integration toolkit.

    Each document represents one row of the result.
    """

    def __init__(
        self,
        query: Union[str, sa.SelectBase],
        db: SQLDatabase,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        page_content_mapper: Optional[Callable[..., str]] = None,
        metadata_mapper: Optional[Callable[..., Dict[str, Any]]] = None,
        source_columns: Optional[Sequence[str]] = None,
        include_rownum_into_metadata: bool = False,
        include_query_into_metadata: bool = False,
    ):
        """
        Args:
            query: The query to execute, either as SQL string or as an
                SQLAlchemy selectable.
            db: A LangChain `SQLDatabase`, wrapping an SQLAlchemy engine.
            parameters: Optional. Parameters to pass to the query.
            page_content_mapper: Optional. Function to convert a row into a string
                to use as the `page_content` of the document. By default, the loader
                serializes the whole row into a string, including all columns.
            metadata_mapper: Optional. Function to convert a row into a dictionary
                to use as the `metadata` of the document. By default, no columns are
                selected into the metadata dictionary.
            source_columns: Optional. The names of the columns to use as the `source`
                within the metadata dictionary.
            include_rownum_into_metadata: Optional. Whether to include the row number
                into the metadata dictionary. Default: False.
            include_query_into_metadata: Optional. Whether to include the query
                expression into the metadata dictionary. Default: False.
        """
        self.query = query
        self.db: SQLDatabase = db
        self.parameters = parameters or {}
        self.page_content_mapper = (
            page_content_mapper or self.page_content_default_mapper
        )
        self.metadata_mapper = metadata_mapper or self.metadata_default_mapper
        self.source_columns = source_columns
        self.include_rownum_into_metadata = include_rownum_into_metadata
        self.include_query_into_metadata = include_query_into_metadata

    def lazy_load(self) -> Iterator[Document]:
        """
        Execute the query and yield one `Document` per result row.

        Raises:
            TypeError: When `query` is neither a string nor an SQLAlchemy
                selectable.
        """
        # NOTE: No guarded `import sqlalchemy` here — the module already imports
        # `sqlalchemy as sa` at the top level, so this module cannot even be
        # imported without SQLAlchemy installed; an in-method guard would be
        # unreachable dead code.

        # Querying in `cursor` fetch mode will return an SQLAlchemy `Result`.
        result: sa.Result[Any]

        # Invoke the database query.
        if isinstance(self.query, sa.SelectBase):
            result = self.db._execute(  # type: ignore[assignment]
                self.query, fetch="cursor", parameters=self.parameters
            )
            # Render the selectable to dialect-specific SQL, for use in metadata.
            query_sql = str(self.query.compile(bind=self.db._engine))
        elif isinstance(self.query, str):
            result = self.db._execute(  # type: ignore[assignment]
                sa.text(self.query), fetch="cursor", parameters=self.parameters
            )
            query_sql = self.query
        else:
            raise TypeError(f"Unable to process query of unknown type: {self.query}")

        # Iterate database result rows and generate documents.
        for i, row in enumerate(result.mappings()):
            page_content = self.page_content_mapper(row)
            metadata = self.metadata_mapper(row)

            if self.include_rownum_into_metadata:
                metadata["row"] = i
            if self.include_query_into_metadata:
                metadata["query"] = query_sql

            # Derive the `source` metadata value from the designated columns.
            if self.source_columns:
                source_values = [
                    value
                    for column, value in row.items()
                    if column in self.source_columns
                ]
                if source_values:
                    # Coerce values to `str`: source columns may hold
                    # non-string data (numbers, dates), and `str.join`
                    # would raise `TypeError` on them otherwise.
                    metadata["source"] = ",".join(map(str, source_values))

            yield Document(page_content=page_content, metadata=metadata)

    def load(self) -> List[Document]:
        """Load all documents eagerly and return them as a list."""
        return list(self.lazy_load())

    @staticmethod
    def page_content_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> str:
        """
        A reasonable default function to convert a record into a "page content"
        string: one `column: value` line per selected column.

        Args:
            row: The database result row.
            column_names: Optional. Restrict the output to these columns.
                By default, all columns are serialized.
        """
        if column_names is None:
            column_names = list(row.keys())
        return "\n".join(
            f"{column}: {value}"
            for column, value in row.items()
            if column in column_names
        )

    @staticmethod
    def metadata_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        A reasonable default function to convert a record into a "metadata"
        dictionary.

        Args:
            row: The database result row.
            column_names: Optional. Columns to copy into the metadata.
                By default, no columns are selected.
        """
        if column_names is None:
            return {}
        return {
            column: value for column, value in row.items() if column in column_names
        }
|
10
libs/community/tests/data.py
Normal file
10
libs/community/tests/data.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""Module defines common test data."""
|
||||
from pathlib import Path
|
||||
|
||||
_THIS_DIR = Path(__file__).parent
|
||||
|
||||
_EXAMPLES_DIR = _THIS_DIR / "examples"
|
||||
|
||||
# Paths to data files
|
||||
MLB_TEAMS_2012_CSV = _EXAMPLES_DIR / "mlb_teams_2012.csv"
|
||||
MLB_TEAMS_2012_SQL = _EXAMPLES_DIR / "mlb_teams_2012.sql"
|
32
libs/community/tests/examples/mlb_teams_2012.csv
Normal file
32
libs/community/tests/examples/mlb_teams_2012.csv
Normal file
@ -0,0 +1,32 @@
|
||||
"Team", "Payroll (millions)", "Wins"
|
||||
"Nationals", 81.34, 98
|
||||
"Reds", 82.20, 97
|
||||
"Yankees", 197.96, 95
|
||||
"Giants", 117.62, 94
|
||||
"Braves", 83.31, 94
|
||||
"Athletics", 55.37, 94
|
||||
"Rangers", 120.51, 93
|
||||
"Orioles", 81.43, 93
|
||||
"Rays", 64.17, 90
|
||||
"Angels", 154.49, 89
|
||||
"Tigers", 132.30, 88
|
||||
"Cardinals", 110.30, 88
|
||||
"Dodgers", 95.14, 86
|
||||
"White Sox", 96.92, 85
|
||||
"Brewers", 97.65, 83
|
||||
"Phillies", 174.54, 81
|
||||
"Diamondbacks", 74.28, 81
|
||||
"Pirates", 63.43, 79
|
||||
"Padres", 55.24, 76
|
||||
"Mariners", 81.97, 75
|
||||
"Mets", 93.35, 74
|
||||
"Blue Jays", 75.48, 73
|
||||
"Royals", 60.91, 72
|
||||
"Marlins", 118.07, 69
|
||||
"Red Sox", 173.18, 69
|
||||
"Indians", 78.43, 68
|
||||
"Twins", 94.08, 66
|
||||
"Rockies", 78.06, 64
|
||||
"Cubs", 88.19, 61
|
||||
"Astros", 60.65, 55
|
||||
|
|
40
libs/community/tests/examples/mlb_teams_2012.sql
Normal file
40
libs/community/tests/examples/mlb_teams_2012.sql
Normal file
@ -0,0 +1,40 @@
|
||||
-- Provisioning table "mlb_teams_2012".
--
-- Usage, e.g. against a local PostgreSQL instance:
-- psql postgresql://postgres@localhost < mlb_teams_2012.sql

DROP TABLE IF EXISTS mlb_teams_2012;
CREATE TABLE mlb_teams_2012 ("Team" VARCHAR, "Payroll (millions)" FLOAT, "Wins" BIGINT);

-- 2012 MLB season data: team name, payroll in millions of USD, season wins.
INSERT INTO mlb_teams_2012
("Team", "Payroll (millions)", "Wins")
VALUES
('Nationals', 81.34, 98),
('Reds', 82.20, 97),
('Yankees', 197.96, 95),
('Giants', 117.62, 94),
('Braves', 83.31, 94),
('Athletics', 55.37, 94),
('Rangers', 120.51, 93),
('Orioles', 81.43, 93),
('Rays', 64.17, 90),
('Angels', 154.49, 89),
('Tigers', 132.30, 88),
('Cardinals', 110.30, 88),
('Dodgers', 95.14, 86),
('White Sox', 96.92, 85),
('Brewers', 97.65, 83),
('Phillies', 174.54, 81),
('Diamondbacks', 74.28, 81),
('Pirates', 63.43, 79),
('Padres', 55.24, 76),
('Mariners', 81.97, 75),
('Mets', 93.35, 74),
('Blue Jays', 75.48, 73),
('Royals', 60.91, 72),
('Marlins', 118.07, 69),
('Red Sox', 173.18, 69),
('Indians', 78.43, 68),
('Twins', 94.08, 66),
('Rockies', 78.06, 64),
('Cubs', 88.19, 61),
('Astros', 60.65, 55)
;
|
@ -0,0 +1,252 @@
|
||||
"""
|
||||
Test SQLAlchemy document loader functionality on behalf of SQLite and PostgreSQL.
|
||||
|
||||
To run the tests for SQLite, you need to have the `sqlite3` package installed.
|
||||
|
||||
To run the tests for PostgreSQL, you need to have the `psycopg2` package installed.
|
||||
In addition, to launch the PostgreSQL instance, you can use the docker compose file
|
||||
located at the root of the repo, `langchain/docker/docker-compose.yml`. Use the
|
||||
command `docker compose up postgres` to start the instance. It will have the
|
||||
appropriate credentials set up including being exposed on the appropriate port.
|
||||
"""
|
||||
import functools
|
||||
import logging
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from _pytest.python import Metafunc
|
||||
|
||||
from langchain_community.document_loaders.sql_database import SQLDatabaseLoader
|
||||
from tests.data import MLB_TEAMS_2012_SQL
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
# Detect which optional database drivers are available. Tests for a given
# backend are only generated (see `pytest_generate_tests`) when its driver
# can be imported.
try:
    import sqlite3  # noqa: F401

    sqlite3_installed = True
except ImportError:
    warnings.warn("sqlite3 not installed, skipping corresponding tests", UserWarning)
    sqlite3_installed = False

try:
    import psycopg2  # noqa: F401

    psycopg2_installed = True
except ImportError:
    warnings.warn("psycopg2 not installed, skipping corresponding tests", UserWarning)
    psycopg2_installed = False
|
||||
|
||||
|
||||
@pytest.fixture()
def engine(db_uri: str) -> sa.Engine:
    """Provide an SQLAlchemy engine for the parameterized database URI."""
    database_engine = sa.create_engine(db_uri, echo=False)
    return database_engine
|
||||
|
||||
|
||||
@pytest.fixture()
def db(engine: sa.Engine) -> SQLDatabase:
    """Wrap the SQLAlchemy engine into a LangChain `SQLDatabase` utility."""
    database = SQLDatabase(engine=engine)
    return database
|
||||
|
||||
|
||||
@pytest.fixture()
def provision_database(engine: sa.Engine) -> None:
    """
    Provision database with table schema and data.

    Reads the provisioning SQL file and replays its statements one by one.
    """
    ddl_dml = MLB_TEAMS_2012_SQL.read_text()
    with engine.connect() as connection:
        # Start from a clean slate on every test run.
        connection.execute(sa.text("DROP TABLE IF EXISTS mlb_teams_2012;"))
        # The SQL file holds multiple statements; execute them individually.
        for raw_statement in ddl_dml.split(";"):
            stmt = raw_statement.strip()
            if not stmt:
                continue
            connection.execute(sa.text(stmt))
        connection.commit()
|
||||
|
||||
|
||||
# Temporary directory hosting the on-disk SQLite test database file.
tmpdir = TemporaryDirectory()


def pytest_generate_tests(metafunc: "Metafunc") -> None:
    """
    Dynamically parameterize test cases to verify both SQLite and PostgreSQL.
    """
    if "db_uri" not in metafunc.fixturenames:
        return

    urls = []
    ids = []
    if sqlite3_installed:
        sqlite_path = Path(tmpdir.name) / "testdrive.sqlite"
        urls.append(f"sqlite:///{sqlite_path}")
        ids.append("sqlite")
    if psycopg2_installed:
        # We use non-standard port for testing purposes.
        # The easiest way to spin up the PostgreSQL instance is to use
        # the docker compose file at the root of the repo located at
        # langchain/docker/docker-compose.yml
        # use `docker compose up postgres` to start the instance
        # it will have the appropriate credentials set up including
        # being exposed on the appropriate port.
        urls.append(
            "postgresql+psycopg2://langchain:langchain@localhost:6023/langchain"
        )
        ids.append("postgresql")

    metafunc.parametrize("db_uri", urls, ids=ids)
|
||||
|
||||
|
||||
def test_sqldatabase_loader_no_options(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader basics."""

    documents = SQLDatabaseLoader("SELECT 1 AS a, 2 AS b", db=db).load()

    assert len(documents) == 1
    document = documents[0]
    assert document.page_content == "a: 1\nb: 2"
    assert document.metadata == {}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_include_rownum_into_metadata(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with row number in metadata."""

    documents = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        include_rownum_into_metadata=True,
    ).load()

    assert len(documents) == 1
    document = documents[0]
    assert document.page_content == "a: 1\nb: 2"
    # Row numbering starts at zero.
    assert document.metadata == {"row": 0}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_include_query_into_metadata(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with query in metadata."""

    documents = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b", db=db, include_query_into_metadata=True
    ).load()

    assert len(documents) == 1
    document = documents[0]
    assert document.page_content == "a: 1\nb: 2"
    # Plain-string queries are propagated verbatim into the metadata.
    assert document.metadata == {"query": "SELECT 1 AS a, 2 AS b"}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_page_content_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined page content columns."""

    # Custom callback converting a row into a "page content" string,
    # restricted to column "a".
    content_from_a = functools.partial(
        SQLDatabaseLoader.page_content_default_mapper, column_names=["a"]
    )

    documents = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b UNION SELECT 3 AS a, 4 AS b",
        db=db,
        page_content_mapper=content_from_a,
    ).load()

    assert len(documents) == 2
    first, second = documents

    assert first.page_content == "a: 1"
    assert first.metadata == {}

    assert second.page_content == "a: 3"
    assert second.metadata == {}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_metadata_columns(db: SQLDatabase) -> None:
    """Test SQLAlchemy loader with defined metadata columns."""

    # Custom callback converting a row into a "metadata" dictionary,
    # restricted to column "b".
    metadata_from_b = functools.partial(
        SQLDatabaseLoader.metadata_default_mapper, column_names=["b"]
    )

    documents = SQLDatabaseLoader(
        "SELECT 1 AS a, 2 AS b",
        db=db,
        metadata_mapper=metadata_from_b,
    ).load()

    assert len(documents) == 1
    assert documents[0].metadata == {"b": 2}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_real_data_with_sql_no_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader with real data, querying by SQL statement."""

    documents = SQLDatabaseLoader(
        query='SELECT * FROM mlb_teams_2012 ORDER BY "Team";',
        db=db,
    ).load()

    # The provisioned table holds 30 rows; ordering by "Team" puts
    # the Angels first.
    assert len(documents) == 30
    first = documents[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    assert first.metadata == {}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_real_data_with_sql_and_parameters(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader, querying by SQL statement and parameters."""

    documents = SQLDatabaseLoader(
        query='SELECT * FROM mlb_teams_2012 WHERE "Team" LIKE :search ORDER BY "Team";',
        parameters={"search": "R%"},
        db=db,
    ).load()

    # Six teams start with "R"; ordering by "Team" puts the Rangers first.
    assert len(documents) == 6
    first = documents[0]
    assert first.page_content == "Team: Rangers\nPayroll (millions): 120.51\nWins: 93"
    assert first.metadata == {}
|
||||
|
||||
|
||||
def test_sqldatabase_loader_real_data_with_selectable(
    db: SQLDatabase, provision_database: None
) -> None:
    """Test SQLAlchemy loader with real data, querying by SQLAlchemy selectable."""

    # SQLAlchemy table definition matching the provisioned schema.
    table = sa.Table(
        "mlb_teams_2012",
        sa.MetaData(),
        sa.Column("Team", sa.VARCHAR),
        sa.Column("Payroll (millions)", sa.FLOAT),
        sa.Column("Wins", sa.BIGINT),
    )

    # Query the database table using an SQLAlchemy selectable.
    statement = sa.select(table).order_by(table.c.Team)
    documents = SQLDatabaseLoader(
        query=statement,
        db=db,
        include_query_into_metadata=True,
    ).load()

    assert len(documents) == 30
    first = documents[0]
    assert first.page_content == "Team: Angels\nPayroll (millions): 154.49\nWins: 89"
    # Selectables are compiled to dialect SQL before being put into metadata.
    assert first.metadata == {
        "query": 'SELECT mlb_teams_2012."Team", mlb_teams_2012."Payroll (millions)", '
        'mlb_teams_2012."Wins" \nFROM mlb_teams_2012 '
        'ORDER BY mlb_teams_2012."Team"'
    }
|
@ -129,6 +129,7 @@ EXPECTED_ALL = [
|
||||
"RocksetLoader",
|
||||
"S3DirectoryLoader",
|
||||
"S3FileLoader",
|
||||
"SQLDatabaseLoader",
|
||||
"SRTLoader",
|
||||
"SeleniumURLLoader",
|
||||
"SharePointLoader",
|
||||
|
Loading…
Reference in New Issue
Block a user