diff --git a/docs/extras/integrations/document_loaders/polars_dataframe.ipynb b/docs/extras/integrations/document_loaders/polars_dataframe.ipynb
new file mode 100644
index 00000000000..52936f16540
--- /dev/null
+++ b/docs/extras/integrations/document_loaders/polars_dataframe.ipynb
@@ -0,0 +1,225 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "213a38a2",
+ "metadata": {},
+ "source": [
+ "# Polars DataFrame\n",
+ "\n",
+ "This notebook goes over how to load data from a [polars](https://pola-rs.github.io/polars-book/user-guide/) DataFrame."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "f6a7a9e4-80d6-486a-b2e3-636c568aa97c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#!pip install polars"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "79331964",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import polars as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e487044c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = pl.read_csv(\"example_data/mlb_teams_2012.csv\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ac273ca1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "
shape: (5, 3)Team | "Payroll (millions)" | "Wins" |
---|
str | f64 | i64 |
"Nationals" | 81.34 | 98 |
"Reds" | 82.2 | 97 |
"Yankees" | 197.96 | 95 |
"Giants" | 117.62 | 94 |
"Braves" | 83.31 | 94 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌───────────┬───────────────────────┬─────────┐\n",
+ "│ Team ┆ \"Payroll (millions)\" ┆ \"Wins\" │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ str ┆ f64 ┆ i64 │\n",
+ "╞═══════════╪═══════════════════════╪═════════╡\n",
+ "│ Nationals ┆ 81.34 ┆ 98 │\n",
+ "│ Reds ┆ 82.2 ┆ 97 │\n",
+ "│ Yankees ┆ 197.96 ┆ 95 │\n",
+ "│ Giants ┆ 117.62 ┆ 94 │\n",
+ "│ Braves ┆ 83.31 ┆ 94 │\n",
+ "└───────────┴───────────────────────┴─────────┘"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "66e47a13",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.document_loaders import PolarsDataFrameLoader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "2334caca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loader = PolarsDataFrameLoader(df, page_content_column=\"Team\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d616c2b0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Document(page_content='Nationals', metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}),\n",
+ " Document(page_content='Reds', metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}),\n",
+ " Document(page_content='Yankees', metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}),\n",
+ " Document(page_content='Giants', metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}),\n",
+ " Document(page_content='Braves', metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}),\n",
+ " Document(page_content='Athletics', metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}),\n",
+ " Document(page_content='Rangers', metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}),\n",
+ " Document(page_content='Orioles', metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}),\n",
+ " Document(page_content='Rays', metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}),\n",
+ " Document(page_content='Angels', metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}),\n",
+ " Document(page_content='Tigers', metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}),\n",
+ " Document(page_content='Cardinals', metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}),\n",
+ " Document(page_content='Dodgers', metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}),\n",
+ " Document(page_content='White Sox', metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}),\n",
+ " Document(page_content='Brewers', metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}),\n",
+ " Document(page_content='Phillies', metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}),\n",
+ " Document(page_content='Diamondbacks', metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}),\n",
+ " Document(page_content='Pirates', metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}),\n",
+ " Document(page_content='Padres', metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}),\n",
+ " Document(page_content='Mariners', metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}),\n",
+ " Document(page_content='Mets', metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}),\n",
+ " Document(page_content='Blue Jays', metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}),\n",
+ " Document(page_content='Royals', metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}),\n",
+ " Document(page_content='Marlins', metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}),\n",
+ " Document(page_content='Red Sox', metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}),\n",
+ " Document(page_content='Indians', metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}),\n",
+ " Document(page_content='Twins', metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}),\n",
+ " Document(page_content='Rockies', metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}),\n",
+ " Document(page_content='Cubs', metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}),\n",
+ " Document(page_content='Astros', metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55})]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "loader.load()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "beb55c2f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "page_content='Nationals' metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}\n",
+ "page_content='Reds' metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}\n",
+ "page_content='Yankees' metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}\n",
+ "page_content='Giants' metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}\n",
+ "page_content='Braves' metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}\n",
+ "page_content='Athletics' metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}\n",
+ "page_content='Rangers' metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}\n",
+ "page_content='Orioles' metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}\n",
+ "page_content='Rays' metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}\n",
+ "page_content='Angels' metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}\n",
+ "page_content='Tigers' metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}\n",
+ "page_content='Cardinals' metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}\n",
+ "page_content='Dodgers' metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}\n",
+ "page_content='White Sox' metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}\n",
+ "page_content='Brewers' metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}\n",
+ "page_content='Phillies' metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}\n",
+ "page_content='Diamondbacks' metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}\n",
+ "page_content='Pirates' metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}\n",
+ "page_content='Padres' metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}\n",
+ "page_content='Mariners' metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}\n",
+ "page_content='Mets' metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}\n",
+ "page_content='Blue Jays' metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}\n",
+ "page_content='Royals' metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}\n",
+ "page_content='Marlins' metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}\n",
+ "page_content='Red Sox' metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}\n",
+ "page_content='Indians' metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}\n",
+ "page_content='Twins' metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}\n",
+ "page_content='Rockies' metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}\n",
+ "page_content='Cubs' metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}\n",
+ "page_content='Astros' metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Use lazy load for larger table, which won't read the full table into memory\n",
+ "for i in loader.lazy_load():\n",
+ " print(i)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py
index 195b586ac2a..30f69659cee 100644
--- a/libs/langchain/langchain/document_loaders/__init__.py
+++ b/libs/langchain/langchain/document_loaders/__init__.py
@@ -132,6 +132,7 @@ from langchain.document_loaders.pdf import (
PyPDFLoader,
UnstructuredPDFLoader,
)
+from langchain.document_loaders.polars_dataframe import PolarsDataFrameLoader
from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
from langchain.document_loaders.psychic import PsychicLoader
from langchain.document_loaders.pubmed import PubMedLoader
@@ -299,6 +300,7 @@ __all__ = [
"PDFPlumberLoader",
"PagedPDFSplitter",
"PlaywrightURLLoader",
+ "PolarsDataFrameLoader",
"PsychicLoader",
"PubMedLoader",
"PyMuPDFLoader",
diff --git a/libs/langchain/langchain/document_loaders/dataframe.py b/libs/langchain/langchain/document_loaders/dataframe.py
index 0476f6a2986..261426a3ceb 100644
--- a/libs/langchain/langchain/document_loaders/dataframe.py
+++ b/libs/langchain/langchain/document_loaders/dataframe.py
@@ -4,23 +4,15 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
-class DataFrameLoader(BaseLoader):
- """Load `Pandas` DataFrame."""
-
- def __init__(self, data_frame: Any, page_content_column: str = "text"):
+class BaseDataFrameLoader(BaseLoader):
+ def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
"""Initialize with dataframe object.
Args:
- data_frame: Pandas DataFrame object.
+ data_frame: DataFrame object.
page_content_column: Name of the column containing the page content.
Defaults to "text".
"""
- import pandas as pd
-
- if not isinstance(data_frame, pd.DataFrame):
- raise ValueError(
- f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}"
- )
self.data_frame = data_frame
self.page_content_column = page_content_column
@@ -36,3 +28,28 @@ class DataFrameLoader(BaseLoader):
def load(self) -> List[Document]:
"""Load full dataframe."""
return list(self.lazy_load())
+
+
+class DataFrameLoader(BaseDataFrameLoader):
+ """Load `Pandas` DataFrame."""
+
+ def __init__(self, data_frame: Any, page_content_column: str = "text"):
+ """Initialize with dataframe object.
+
+ Args:
+ data_frame: Pandas DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "text".
+ """
+ try:
+ import pandas as pd
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import pandas, please install with `pip install pandas`."
+ ) from e
+
+ if not isinstance(data_frame, pd.DataFrame):
+ raise ValueError(
+ f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}"
+ )
+ super().__init__(data_frame, page_content_column=page_content_column)
diff --git a/libs/langchain/langchain/document_loaders/polars_dataframe.py b/libs/langchain/langchain/document_loaders/polars_dataframe.py
new file mode 100644
index 00000000000..6ece942df49
--- /dev/null
+++ b/libs/langchain/langchain/document_loaders/polars_dataframe.py
@@ -0,0 +1,32 @@
+from typing import Any, Iterator
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.dataframe import BaseDataFrameLoader
+
+
+class PolarsDataFrameLoader(BaseDataFrameLoader):
+ """Load `Polars` DataFrame."""
+
+ def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
+ """Initialize with dataframe object.
+
+ Args:
+ data_frame: Polars DataFrame object.
+ page_content_column: Name of the column containing the page content.
+ Defaults to "text".
+ """
+ import polars as pl
+
+ if not isinstance(data_frame, pl.DataFrame):
+ raise ValueError(
+ f"Expected data_frame to be a pl.DataFrame, got {type(data_frame)}"
+ )
+ super().__init__(data_frame, page_content_column=page_content_column)
+
+ def lazy_load(self) -> Iterator[Document]:
+ """Lazy load records from dataframe."""
+
+ for row in self.data_frame.iter_rows(named=True):
+ text = row[self.page_content_column]
+ row.pop(self.page_content_column)
+ yield Document(page_content=text, metadata=row)
diff --git a/libs/langchain/langchain/document_loaders/xorbits.py b/libs/langchain/langchain/document_loaders/xorbits.py
index bcc4e680f68..723e9dc1b59 100644
--- a/libs/langchain/langchain/document_loaders/xorbits.py
+++ b/libs/langchain/langchain/document_loaders/xorbits.py
@@ -1,10 +1,9 @@
-from typing import Any, Iterator, List
+from typing import Any
-from langchain.docstore.document import Document
-from langchain.document_loaders.base import BaseLoader
+from langchain.document_loaders.dataframe import BaseDataFrameLoader
-class XorbitsLoader(BaseLoader):
+class XorbitsLoader(BaseDataFrameLoader):
"""Load `Xorbits` DataFrame."""
def __init__(self, data_frame: Any, page_content_column: str = "text"):
@@ -30,17 +29,4 @@ class XorbitsLoader(BaseLoader):
f"Expected data_frame to be a xorbits.pandas.DataFrame, \
got {type(data_frame)}"
)
- self.data_frame = data_frame
- self.page_content_column = page_content_column
-
- def lazy_load(self) -> Iterator[Document]:
- """Lazy load records from dataframe."""
- for _, row in self.data_frame.iterrows():
- text = row[self.page_content_column]
- metadata = row.to_dict()
- metadata.pop(self.page_content_column)
- yield Document(page_content=text, metadata=metadata)
-
- def load(self) -> List[Document]:
- """Load full dataframe."""
- return list(self.lazy_load())
+ super().__init__(data_frame, page_content_column=page_content_column)
diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py
new file mode 100644
index 00000000000..bd8e129dafa
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/document_loaders/test_polars_dataframe.py
@@ -0,0 +1,48 @@
+import polars as pl
+import pytest
+
+from langchain.document_loaders import PolarsDataFrameLoader
+from langchain.schema import Document
+
+
+@pytest.fixture
+def sample_data_frame() -> pl.DataFrame:
+ data = {
+ "text": ["Hello", "World"],
+ "author": ["Alice", "Bob"],
+ "date": ["2022-01-01", "2022-01-02"],
+ }
+ return pl.DataFrame(data)
+
+
+def test_load_returns_list_of_documents(sample_data_frame: pl.DataFrame) -> None:
+ loader = PolarsDataFrameLoader(sample_data_frame)
+ docs = loader.load()
+ assert isinstance(docs, list)
+ assert all(isinstance(doc, Document) for doc in docs)
+ assert len(docs) == 2
+
+
+def test_load_converts_dataframe_columns_to_document_metadata(
+ sample_data_frame: pl.DataFrame,
+) -> None:
+ loader = PolarsDataFrameLoader(sample_data_frame)
+ docs = loader.load()
+
+ for i, doc in enumerate(docs):
+ df: pl.DataFrame = sample_data_frame[i]
+ assert df is not None
+ assert doc.metadata["author"] == df.select("author").item()
+ assert doc.metadata["date"] == df.select("date").item()
+
+
+def test_load_uses_page_content_column_to_create_document_text(
+ sample_data_frame: pl.DataFrame,
+) -> None:
+ sample_data_frame = sample_data_frame.rename(mapping={"text": "dummy_test_column"})
+ loader = PolarsDataFrameLoader(
+ sample_data_frame, page_content_column="dummy_test_column"
+ )
+ docs = loader.load()
+ assert docs[0].page_content == "Hello"
+ assert docs[1].page_content == "World"