diff --git a/docs/extras/integrations/document_loaders/polars_dataframe.ipynb b/docs/extras/integrations/document_loaders/polars_dataframe.ipynb new file mode 100644 index 00000000000..52936f16540 --- /dev/null +++ b/docs/extras/integrations/document_loaders/polars_dataframe.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "213a38a2", + "metadata": {}, + "source": [ + "# Polars DataFrame\n", + "\n", + "This notebook goes over how to load data from a [polars](https://pola-rs.github.io/polars-book/user-guide/) DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f6a7a9e4-80d6-486a-b2e3-636c568aa97c", + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install polars" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "79331964", + "metadata": {}, + "outputs": [], + "source": [ + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e487044c", + "metadata": {}, + "outputs": [], + "source": [ + "df = pl.read_csv(\"example_data/mlb_teams_2012.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ac273ca1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
Team "Payroll (millions)" "Wins"
strf64i64
"Nationals"81.3498
"Reds"82.297
"Yankees"197.9695
"Giants"117.6294
"Braves"83.3194
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌───────────┬───────────────────────┬─────────┐\n", + "│ Team ┆ \"Payroll (millions)\" ┆ \"Wins\" │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ str ┆ f64 ┆ i64 │\n", + "╞═══════════╪═══════════════════════╪═════════╡\n", + "│ Nationals ┆ 81.34 ┆ 98 │\n", + "│ Reds ┆ 82.2 ┆ 97 │\n", + "│ Yankees ┆ 197.96 ┆ 95 │\n", + "│ Giants ┆ 117.62 ┆ 94 │\n", + "│ Braves ┆ 83.31 ┆ 94 │\n", + "└───────────┴───────────────────────┴─────────┘" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "66e47a13", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import PolarsDataFrameLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2334caca", + "metadata": {}, + "outputs": [], + "source": [ + "loader = PolarsDataFrameLoader(df, page_content_column=\"Team\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d616c2b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='Nationals', metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}),\n", + " Document(page_content='Reds', metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}),\n", + " Document(page_content='Yankees', metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}),\n", + " Document(page_content='Giants', metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}),\n", + " Document(page_content='Braves', metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}),\n", + " Document(page_content='Athletics', metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}),\n", + " Document(page_content='Rangers', metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}),\n", + " Document(page_content='Orioles', metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}),\n", + " Document(page_content='Rays', metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}),\n", + " Document(page_content='Angels', metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}),\n", + " Document(page_content='Tigers', metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}),\n", + " Document(page_content='Cardinals', metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}),\n", + " Document(page_content='Dodgers', metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}),\n", + " Document(page_content='White Sox', metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}),\n", + " Document(page_content='Brewers', metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}),\n", + " Document(page_content='Phillies', metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}),\n", + " Document(page_content='Diamondbacks', metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}),\n", + " Document(page_content='Pirates', metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}),\n", + " Document(page_content='Padres', metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}),\n", + " Document(page_content='Mariners', metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}),\n", + " Document(page_content='Mets', metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}),\n", + " Document(page_content='Blue Jays', metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}),\n", + " Document(page_content='Royals', metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}),\n", + " Document(page_content='Marlins', metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}),\n", + " Document(page_content='Red Sox', metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}),\n", + " Document(page_content='Indians', metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}),\n", + " Document(page_content='Twins', metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}),\n", + " Document(page_content='Rockies', metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}),\n", + " Document(page_content='Cubs', metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}),\n", + " Document(page_content='Astros', metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55})]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "beb55c2f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "page_content='Nationals' metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}\n", + "page_content='Reds' metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}\n", + "page_content='Yankees' metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}\n", + "page_content='Giants' metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}\n", + "page_content='Braves' metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}\n", + "page_content='Athletics' metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}\n", + "page_content='Rangers' metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}\n", + "page_content='Orioles' metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}\n", + "page_content='Rays' metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}\n", + "page_content='Angels' metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}\n", + "page_content='Tigers' metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}\n", + "page_content='Cardinals' metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}\n", + "page_content='Dodgers' metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}\n", + "page_content='White Sox' metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}\n", + "page_content='Brewers' metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}\n", + "page_content='Phillies' metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}\n", + "page_content='Diamondbacks' metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}\n", + "page_content='Pirates' metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}\n", + "page_content='Padres' metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}\n", + "page_content='Mariners' metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}\n", + "page_content='Mets' metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}\n", + "page_content='Blue Jays' metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}\n", + "page_content='Royals' metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}\n", + "page_content='Marlins' metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}\n", + "page_content='Red Sox' metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}\n", + "page_content='Indians' metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}\n", + "page_content='Twins' metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}\n", + "page_content='Rockies' metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}\n", + "page_content='Cubs' metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}\n", + "page_content='Astros' metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55}\n" + ] + } + ], + "source": [ + "# Use lazy load for larger table, which won't read the full table into memory\n", + "for i in loader.lazy_load():\n", + " print(i)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index 195b586ac2a..30f69659cee 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -132,6 +132,7 @@ from langchain.document_loaders.pdf import ( PyPDFLoader, UnstructuredPDFLoader, ) +from langchain.document_loaders.polars_dataframe import PolarsDataFrameLoader from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader from langchain.document_loaders.psychic import PsychicLoader from langchain.document_loaders.pubmed import PubMedLoader @@ -299,6 +300,7 @@ __all__ = [ "PDFPlumberLoader", "PagedPDFSplitter", "PlaywrightURLLoader", + "PolarsDataFrameLoader", "PsychicLoader", "PubMedLoader", "PyMuPDFLoader", diff --git a/libs/langchain/langchain/document_loaders/dataframe.py b/libs/langchain/langchain/document_loaders/dataframe.py index 0476f6a2986..261426a3ceb 100644 --- a/libs/langchain/langchain/document_loaders/dataframe.py +++ b/libs/langchain/langchain/document_loaders/dataframe.py @@ -4,23 +4,15 @@ from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -class DataFrameLoader(BaseLoader): - """Load `Pandas` DataFrame.""" - - def __init__(self, data_frame: Any, page_content_column: str = "text"): +class BaseDataFrameLoader(BaseLoader): + def __init__(self, data_frame: Any, *, page_content_column: str = "text"): """Initialize with dataframe object. Args: - data_frame: Pandas DataFrame object. + data_frame: DataFrame object. page_content_column: Name of the column containing the page content. Defaults to "text". """ - import pandas as pd - - if not isinstance(data_frame, pd.DataFrame): - raise ValueError( - f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}" - ) self.data_frame = data_frame self.page_content_column = page_content_column @@ -36,3 +28,28 @@ class DataFrameLoader(BaseLoader): def load(self) -> List[Document]: """Load full dataframe.""" return list(self.lazy_load()) + + +class DataFrameLoader(BaseDataFrameLoader): + """Load `Pandas` DataFrame.""" + + def __init__(self, data_frame: Any, page_content_column: str = "text"): + """Initialize with dataframe object. + + Args: + data_frame: Pandas DataFrame object. + page_content_column: Name of the column containing the page content. + Defaults to "text". + """ + try: + import pandas as pd + except ImportError as e: + raise ImportError( + "Unable to import pandas, please install with `pip install pandas`." + ) from e + + if not isinstance(data_frame, pd.DataFrame): + raise ValueError( + f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}" + ) + super().__init__(data_frame, page_content_column=page_content_column) diff --git a/libs/langchain/langchain/document_loaders/polars_dataframe.py b/libs/langchain/langchain/document_loaders/polars_dataframe.py new file mode 100644 index 00000000000..6ece942df49 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/polars_dataframe.py @@ -0,0 +1,32 @@ +from typing import Any, Iterator + +from langchain.docstore.document import Document +from langchain.document_loaders.dataframe import BaseDataFrameLoader + + +class PolarsDataFrameLoader(BaseDataFrameLoader): + """Load `Polars` DataFrame.""" + + def __init__(self, data_frame: Any, *, page_content_column: str = "text"): + """Initialize with dataframe object. + + Args: + data_frame: Polars DataFrame object. + page_content_column: Name of the column containing the page content. + Defaults to "text". + """ + import polars as pl + + if not isinstance(data_frame, pl.DataFrame): + raise ValueError( + f"Expected data_frame to be a pl.DataFrame, got {type(data_frame)}" + ) + super().__init__(data_frame, page_content_column=page_content_column) + + def lazy_load(self) -> Iterator[Document]: + """Lazy load records from dataframe.""" + + for row in self.data_frame.iter_rows(named=True): + text = row[self.page_content_column] + row.pop(self.page_content_column) + yield Document(page_content=text, metadata=row) diff --git a/libs/langchain/langchain/document_loaders/xorbits.py b/libs/langchain/langchain/document_loaders/xorbits.py index bcc4e680f68..723e9dc1b59 100644 --- a/libs/langchain/langchain/document_loaders/xorbits.py +++ b/libs/langchain/langchain/document_loaders/xorbits.py @@ -1,10 +1,9 @@ -from typing import Any, Iterator, List +from typing import Any -from langchain.docstore.document import Document -from langchain.document_loaders.base import BaseLoader +from langchain.document_loaders.dataframe import BaseDataFrameLoader -class XorbitsLoader(BaseLoader): +class XorbitsLoader(BaseDataFrameLoader): """Load `Xorbits` DataFrame.""" def __init__(self, data_frame: Any, page_content_column: str = "text"): @@ -30,17 +29,4 @@ class XorbitsLoader(BaseLoader): f"Expected data_frame to be a xorbits.pandas.DataFrame, \ got {type(data_frame)}" ) - self.data_frame = data_frame - self.page_content_column = page_content_column - - def lazy_load(self) -> Iterator[Document]: - """Lazy load records from dataframe.""" - for _, row in self.data_frame.iterrows(): - text = row[self.page_content_column] - metadata = row.to_dict() - metadata.pop(self.page_content_column) - yield Document(page_content=text, metadata=metadata) - - def load(self) -> List[Document]: - """Load full dataframe.""" - return list(self.lazy_load()) + super().__init__(data_frame, page_content_column=page_content_column) diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py new file mode 100644 index 00000000000..bd8e129dafa --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py @@ -0,0 +1,48 @@ +import polars as pl +import pytest + +from langchain.document_loaders import PolarsDataFrameLoader +from langchain.schema import Document + + +@pytest.fixture +def sample_data_frame() -> pl.DataFrame: + data = { + "text": ["Hello", "World"], + "author": ["Alice", "Bob"], + "date": ["2022-01-01", "2022-01-02"], + } + return pl.DataFrame(data) + + +def test_load_returns_list_of_documents(sample_data_frame: pl.DataFrame) -> None: + loader = PolarsDataFrameLoader(sample_data_frame) + docs = loader.load() + assert isinstance(docs, list) + assert all(isinstance(doc, Document) for doc in docs) + assert len(docs) == 2 + + +def test_load_converts_dataframe_columns_to_document_metadata( + sample_data_frame: pl.DataFrame, +) -> None: + loader = PolarsDataFrameLoader(sample_data_frame) + docs = loader.load() + + for i, doc in enumerate(docs): + df: pl.DataFrame = sample_data_frame[i] + assert df is not None + assert doc.metadata["author"] == df.select("author").item() + assert doc.metadata["date"] == df.select("date").item() + + +def test_load_uses_page_content_column_to_create_document_text( + sample_data_frame: pl.DataFrame, +) -> None: + sample_data_frame = sample_data_frame.rename(mapping={"text": "dummy_test_column"}) + loader = PolarsDataFrameLoader( + sample_data_frame, page_content_column="dummy_test_column" + ) + docs = loader.load() + assert docs[0].page_content == "Hello" + assert docs[1].page_content == "World"