From fba29f203adfe35d14a03511c88661ddd1c76263 Mon Sep 17 00:00:00 2001 From: toddkim95 <42592581+toddkim95@users.noreply.github.com> Date: Tue, 22 Aug 2023 23:36:24 +0900 Subject: [PATCH] Add to support polars (#9610) ### Description Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Rust. Polars is faster to read than pandas, so I'm looking forward to seeing it added to the document loader. ### Dependencies polars (https://pola-rs.github.io/polars-book/user-guide/) --------- Co-authored-by: Bagatur --- .../document_loaders/polars_dataframe.ipynb | 225 ++++++++++++++++++ .../langchain/document_loaders/__init__.py | 2 + .../langchain/document_loaders/dataframe.py | 39 ++- .../document_loaders/polars_dataframe.py | 32 +++ .../langchain/document_loaders/xorbits.py | 22 +- .../document_loaders/test_polars_dataframe.py | 48 ++++ 6 files changed, 339 insertions(+), 29 deletions(-) create mode 100644 docs/extras/integrations/document_loaders/polars_dataframe.ipynb create mode 100644 libs/langchain/langchain/document_loaders/polars_dataframe.py create mode 100644 libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py diff --git a/docs/extras/integrations/document_loaders/polars_dataframe.ipynb b/docs/extras/integrations/document_loaders/polars_dataframe.ipynb new file mode 100644 index 00000000000..52936f16540 --- /dev/null +++ b/docs/extras/integrations/document_loaders/polars_dataframe.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "213a38a2", + "metadata": {}, + "source": [ + "# Polars DataFrame\n", + "\n", + "This notebook goes over how to load data from a [polars](https://pola-rs.github.io/polars-book/user-guide/) DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f6a7a9e4-80d6-486a-b2e3-636c568aa97c", + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install polars" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "79331964", + "metadata": {}, + "outputs": [], + "source": [ + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e487044c", + "metadata": {}, + "outputs": [], + "source": [ + "df = pl.read_csv(\"example_data/mlb_teams_2012.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ac273ca1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
Team "Payroll (millions)" "Wins"
strf64i64
"Nationals"81.3498
"Reds"82.297
"Yankees"197.9695
"Giants"117.6294
"Braves"83.3194
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌───────────┬───────────────────────┬─────────┐\n", + "│ Team ┆ \"Payroll (millions)\" ┆ \"Wins\" │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ str ┆ f64 ┆ i64 │\n", + "╞═══════════╪═══════════════════════╪═════════╡\n", + "│ Nationals ┆ 81.34 ┆ 98 │\n", + "│ Reds ┆ 82.2 ┆ 97 │\n", + "│ Yankees ┆ 197.96 ┆ 95 │\n", + "│ Giants ┆ 117.62 ┆ 94 │\n", + "│ Braves ┆ 83.31 ┆ 94 │\n", + "└───────────┴───────────────────────┴─────────┘" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "66e47a13", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import PolarsDataFrameLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2334caca", + "metadata": {}, + "outputs": [], + "source": [ + "loader = PolarsDataFrameLoader(df, page_content_column=\"Team\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d616c2b0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='Nationals', metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}),\n", + " Document(page_content='Reds', metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}),\n", + " Document(page_content='Yankees', metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}),\n", + " Document(page_content='Giants', metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}),\n", + " Document(page_content='Braves', metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}),\n", + " Document(page_content='Athletics', metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}),\n", + " Document(page_content='Rangers', metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}),\n", + " Document(page_content='Orioles', metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}),\n", + " Document(page_content='Rays', metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}),\n", + " Document(page_content='Angels', metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}),\n", + " Document(page_content='Tigers', metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}),\n", + " Document(page_content='Cardinals', metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}),\n", + " Document(page_content='Dodgers', metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}),\n", + " Document(page_content='White Sox', metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}),\n", + " Document(page_content='Brewers', metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}),\n", + " Document(page_content='Phillies', metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}),\n", + " Document(page_content='Diamondbacks', metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}),\n", + " Document(page_content='Pirates', metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}),\n", + " Document(page_content='Padres', metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}),\n", + " Document(page_content='Mariners', metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}),\n", + " Document(page_content='Mets', metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}),\n", + " Document(page_content='Blue Jays', metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}),\n", + " Document(page_content='Royals', metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}),\n", + " Document(page_content='Marlins', metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}),\n", + " Document(page_content='Red Sox', metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}),\n", + " Document(page_content='Indians', metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}),\n", + " Document(page_content='Twins', metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}),\n", + " Document(page_content='Rockies', metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}),\n", + " Document(page_content='Cubs', metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}),\n", + " Document(page_content='Astros', metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55})]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "beb55c2f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "page_content='Nationals' metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}\n", + "page_content='Reds' metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}\n", + "page_content='Yankees' metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}\n", + "page_content='Giants' metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}\n", + "page_content='Braves' metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}\n", + "page_content='Athletics' metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}\n", + "page_content='Rangers' metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}\n", + "page_content='Orioles' metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}\n", + "page_content='Rays' metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}\n", + "page_content='Angels' metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}\n", + "page_content='Tigers' metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}\n", + "page_content='Cardinals' metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}\n", + "page_content='Dodgers' metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}\n", + "page_content='White Sox' metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}\n", + "page_content='Brewers' metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}\n", + "page_content='Phillies' metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}\n", + "page_content='Diamondbacks' metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}\n", + "page_content='Pirates' metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}\n", + "page_content='Padres' metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}\n", + "page_content='Mariners' metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}\n", + "page_content='Mets' metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}\n", + "page_content='Blue Jays' metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}\n", + "page_content='Royals' metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}\n", + "page_content='Marlins' metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}\n", + "page_content='Red Sox' metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}\n", + "page_content='Indians' metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}\n", + "page_content='Twins' metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}\n", + "page_content='Rockies' metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}\n", + "page_content='Cubs' metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}\n", + "page_content='Astros' metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55}\n" + ] + } + ], + "source": [ + "# Use lazy load for larger table, which won't read the full table into memory\n", + "for i in loader.lazy_load():\n", + " print(i)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index 195b586ac2a..30f69659cee 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -132,6 +132,7 @@ from langchain.document_loaders.pdf import ( PyPDFLoader, UnstructuredPDFLoader, ) +from langchain.document_loaders.polars_dataframe import PolarsDataFrameLoader from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader from langchain.document_loaders.psychic import PsychicLoader from langchain.document_loaders.pubmed import PubMedLoader @@ -299,6 +300,7 @@ __all__ = [ "PDFPlumberLoader", "PagedPDFSplitter", "PlaywrightURLLoader", + "PolarsDataFrameLoader", "PsychicLoader", "PubMedLoader", "PyMuPDFLoader", diff --git a/libs/langchain/langchain/document_loaders/dataframe.py b/libs/langchain/langchain/document_loaders/dataframe.py index 0476f6a2986..261426a3ceb 100644 --- a/libs/langchain/langchain/document_loaders/dataframe.py +++ b/libs/langchain/langchain/document_loaders/dataframe.py @@ -4,23 +4,15 @@ from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -class DataFrameLoader(BaseLoader): - """Load `Pandas` DataFrame.""" - - def __init__(self, data_frame: Any, page_content_column: str = "text"): +class BaseDataFrameLoader(BaseLoader): + def __init__(self, data_frame: Any, *, page_content_column: str = "text"): """Initialize with dataframe object. Args: - data_frame: Pandas DataFrame object. + data_frame: DataFrame object. page_content_column: Name of the column containing the page content. Defaults to "text". """ - import pandas as pd - - if not isinstance(data_frame, pd.DataFrame): - raise ValueError( - f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}" - ) self.data_frame = data_frame self.page_content_column = page_content_column @@ -36,3 +28,28 @@ class DataFrameLoader(BaseLoader): def load(self) -> List[Document]: """Load full dataframe.""" return list(self.lazy_load()) + + +class DataFrameLoader(BaseDataFrameLoader): + """Load `Pandas` DataFrame.""" + + def __init__(self, data_frame: Any, page_content_column: str = "text"): + """Initialize with dataframe object. + + Args: + data_frame: Pandas DataFrame object. + page_content_column: Name of the column containing the page content. + Defaults to "text". + """ + try: + import pandas as pd + except ImportError as e: + raise ImportError( + "Unable to import pandas, please install with `pip install pandas`." + ) from e + + if not isinstance(data_frame, pd.DataFrame): + raise ValueError( + f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}" + ) + super().__init__(data_frame, page_content_column=page_content_column) diff --git a/libs/langchain/langchain/document_loaders/polars_dataframe.py b/libs/langchain/langchain/document_loaders/polars_dataframe.py new file mode 100644 index 00000000000..6ece942df49 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/polars_dataframe.py @@ -0,0 +1,32 @@ +from typing import Any, Iterator + +from langchain.docstore.document import Document +from langchain.document_loaders.dataframe import BaseDataFrameLoader + + +class PolarsDataFrameLoader(BaseDataFrameLoader): + """Load `Polars` DataFrame.""" + + def __init__(self, data_frame: Any, *, page_content_column: str = "text"): + """Initialize with dataframe object. + + Args: + data_frame: Polars DataFrame object. + page_content_column: Name of the column containing the page content. + Defaults to "text". + """ + import polars as pl + + if not isinstance(data_frame, pl.DataFrame): + raise ValueError( + f"Expected data_frame to be a pl.DataFrame, got {type(data_frame)}" + ) + super().__init__(data_frame, page_content_column=page_content_column) + + def lazy_load(self) -> Iterator[Document]: + """Lazy load records from dataframe.""" + + for row in self.data_frame.iter_rows(named=True): + text = row[self.page_content_column] + row.pop(self.page_content_column) + yield Document(page_content=text, metadata=row) diff --git a/libs/langchain/langchain/document_loaders/xorbits.py b/libs/langchain/langchain/document_loaders/xorbits.py index bcc4e680f68..723e9dc1b59 100644 --- a/libs/langchain/langchain/document_loaders/xorbits.py +++ b/libs/langchain/langchain/document_loaders/xorbits.py @@ -1,10 +1,9 @@ -from typing import Any, Iterator, List +from typing import Any -from langchain.docstore.document import Document -from langchain.document_loaders.base import BaseLoader +from langchain.document_loaders.dataframe import BaseDataFrameLoader -class XorbitsLoader(BaseLoader): +class XorbitsLoader(BaseDataFrameLoader): """Load `Xorbits` DataFrame.""" def __init__(self, data_frame: Any, page_content_column: str = "text"): @@ -30,17 +29,4 @@ class XorbitsLoader(BaseLoader): f"Expected data_frame to be a xorbits.pandas.DataFrame, \ got {type(data_frame)}" ) - self.data_frame = data_frame - self.page_content_column = page_content_column - - def lazy_load(self) -> Iterator[Document]: - """Lazy load records from dataframe.""" - for _, row in self.data_frame.iterrows(): - text = row[self.page_content_column] - metadata = row.to_dict() - metadata.pop(self.page_content_column) - yield Document(page_content=text, metadata=metadata) - - def load(self) -> List[Document]: - """Load full dataframe.""" - return list(self.lazy_load()) + super().__init__(data_frame, page_content_column=page_content_column) diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py new file mode 100644 index 00000000000..bd8e129dafa --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py @@ -0,0 +1,48 @@ +import polars as pl +import pytest + +from langchain.document_loaders import PolarsDataFrameLoader +from langchain.schema import Document + + +@pytest.fixture +def sample_data_frame() -> pl.DataFrame: + data = { + "text": ["Hello", "World"], + "author": ["Alice", "Bob"], + "date": ["2022-01-01", "2022-01-02"], + } + return pl.DataFrame(data) + + +def test_load_returns_list_of_documents(sample_data_frame: pl.DataFrame) -> None: + loader = PolarsDataFrameLoader(sample_data_frame) + docs = loader.load() + assert isinstance(docs, list) + assert all(isinstance(doc, Document) for doc in docs) + assert len(docs) == 2 + + +def test_load_converts_dataframe_columns_to_document_metadata( + sample_data_frame: pl.DataFrame, +) -> None: + loader = PolarsDataFrameLoader(sample_data_frame) + docs = loader.load() + + for i, doc in enumerate(docs): + df: pl.DataFrame = sample_data_frame[i] + assert df is not None + assert doc.metadata["author"] == df.select("author").item() + assert doc.metadata["date"] == df.select("date").item() + + +def test_load_uses_page_content_column_to_create_document_text( + sample_data_frame: pl.DataFrame, +) -> None: + sample_data_frame = sample_data_frame.rename(mapping={"text": "dummy_test_column"}) + loader = PolarsDataFrameLoader( + sample_data_frame, page_content_column="dummy_test_column" + ) + docs = loader.load() + assert docs[0].page_content == "Hello" + assert docs[1].page_content == "World"