Harrison/spark reader (#5405)

Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>

parent 8259f9b7fa
commit 760632b292
@@ -130,6 +130,7 @@ We need access tokens and sometime other parameters to get access to these datas
 ./document_loaders/examples/notion.ipynb
 ./document_loaders/examples/obsidian.ipynb
 ./document_loaders/examples/psychic.ipynb
+./document_loaders/examples/pyspark_dataframe.ipynb
 ./document_loaders/examples/readthedocs_documentation.ipynb
 ./document_loaders/examples/reddit.ipynb
 ./document_loaders/examples/roam.ipynb
./document_loaders/examples/pyspark_dataframe.ipynb (new file)
@@ -0,0 +1,97 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PySpark DataFrame Loader\n",
    "\n",
    "This shows how to load data from a PySpark DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install pyspark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark = SparkSession.builder.getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import PySparkDataFrameLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader.load()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
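The notebook above only calls loader.load(), which materializes a list of Documents. A minimal sketch of the streaming alternative, lazy_load(), reusing the same session and the notebook's example CSV path (the path is illustrative; any small CSV works):

from pyspark.sql import SparkSession

from langchain.document_loaders import PySparkDataFrameLoader

spark = SparkSession.builder.getOrCreate()
df = spark.read.csv("example_data/mlb_teams_2012.csv", header=True)

loader = PySparkDataFrameLoader(spark, df, page_content_column="Team")

# lazy_load() yields one Document per DataFrame row instead of building the
# whole list in memory, which matters for DataFrames larger than this example.
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)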
langchain/document_loaders/__init__.py
@@ -74,6 +74,7 @@ from langchain.document_loaders.pdf import (
 )
 from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
 from langchain.document_loaders.psychic import PsychicLoader
+from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
 from langchain.document_loaders.python import PythonLoader
 from langchain.document_loaders.readthedocs import ReadTheDocsLoader
 from langchain.document_loaders.reddit import RedditPostsLoader
@@ -188,6 +189,7 @@ __all__ = [
     "PyPDFDirectoryLoader",
     "PyPDFLoader",
     "PyPDFium2Loader",
+    "PySparkDataFrameLoader",
     "PythonLoader",
     "ReadTheDocsLoader",
     "RedditPostsLoader",
langchain/document_loaders/pyspark_dataframe.py (new file, 80 lines)
@@ -0,0 +1,80 @@
"""Load from a Spark Dataframe object"""
import itertools
import logging
import sys
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple

import psutil

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader

logger = logging.getLogger(__file__)

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


class PySparkDataFrameLoader(BaseLoader):
    """Load PySpark DataFrames"""

    def __init__(
        self,
        spark_session: Optional["SparkSession"] = None,
        df: Optional[Any] = None,
        page_content_column: str = "text",
        fraction_of_memory: float = 0.1,
    ):
        """Initialize with a Spark DataFrame object."""
        try:
            from pyspark.sql import DataFrame, SparkSession
        except ImportError:
            raise ValueError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            )

        self.spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )

        if not isinstance(df, DataFrame):
            raise ValueError(
                f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
            )
        self.df = df
        self.page_content_column = page_content_column
        self.fraction_of_memory = fraction_of_memory
        self.num_rows, self.max_num_rows = self.get_num_rows()
        self.rdd_df = self.df.rdd.map(list)
        self.column_names = self.df.columns

    def get_num_rows(self) -> Tuple[int, int]:
        """Gets the amount of "feasible" rows for the DataFrame"""
        row = self.df.limit(1).collect()[0]
        estimated_row_size = sys.getsizeof(row)
        mem_info = psutil.virtual_memory()
        available_memory = mem_info.available
        max_num_rows = int(
            (available_memory / estimated_row_size) * self.fraction_of_memory
        )
        return min(max_num_rows, self.df.count()), max_num_rows

    def lazy_load(self) -> Iterator[Document]:
        """A lazy loader for document content."""
        for row in self.rdd_df.toLocalIterator():
            metadata = {self.column_names[i]: row[i] for i in range(len(row))}
            text = metadata[self.page_content_column]
            metadata.pop(self.page_content_column)
            yield Document(page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        """Load from the dataframe."""
        if self.df.count() > self.max_num_rows:
            logger.warning(
                f"The number of DataFrame rows is {self.df.count()}, "
                f"but we will only include the amount "
                f"of rows that can reasonably fit in memory: {self.num_rows}."
            )
        lazy_load_iterator = self.lazy_load()
        return list(itertools.islice(lazy_load_iterator, self.num_rows))
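As context for the implementation above: get_num_rows estimates the in-memory size of one collected row and caps load() at a fraction of currently available memory. A standalone sketch of that arithmetic, using an assumed tuple as a stand-in for a collected Spark Row (the loader measures the DataFrame's actual first row instead):

import sys

import psutil

# Assumed stand-in for one collected row; the loader calls sys.getsizeof on
# self.df.limit(1).collect()[0] instead of a hard-coded tuple.
sample_row = ("Nationals", "81.34", "98")
estimated_row_size = sys.getsizeof(sample_row)

available_memory = psutil.virtual_memory().available
fraction_of_memory = 0.1  # the loader's default

# Same formula as PySparkDataFrameLoader.get_num_rows: how many rows of this
# size fit into the chosen fraction of available memory.
max_num_rows = int((available_memory / estimated_row_size) * fraction_of_memory)
print(f"Would cap load() at roughly {max_num_rows} rows")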
poetry.lock (generated, 37 lines changed)
@@ -6643,6 +6643,18 @@ pytz = "*"
 requests = "*"
 requests-oauthlib = ">=0.4.1"
 
+[[package]]
+name = "py4j"
+version = "0.10.9.7"
+description = "Enables Python programs to dynamically access arbitrary Java objects"
+category = "main"
+optional = true
+python-versions = "*"
+files = [
+    {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"},
+    {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"},
+]
+
 [[package]]
 name = "pyaes"
 version = "1.6.1"
@@ -7229,6 +7241,27 @@ files = [
     {file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"},
 ]
 
+[[package]]
+name = "pyspark"
+version = "3.4.0"
+description = "Apache Spark Python API"
+category = "main"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "pyspark-3.4.0.tar.gz", hash = "sha256:167a23e11854adb37f8602de6fcc3a4f96fd5f1e323b9bb83325f38408c5aafd"},
+]
+
+[package.dependencies]
+py4j = "0.10.9.7"
+
+[package.extras]
+connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
+ml = ["numpy (>=1.15)"]
+mllib = ["numpy (>=1.15)"]
+pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
+sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
+
 [[package]]
 name = "pytesseract"
 version = "0.3.10"
@@ -10920,7 +10953,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
+extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
 llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
 qdrant = ["qdrant-client"]
@@ -10929,4 +10962,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "1033e47cdab7d3a15fb9322bad64609f77fd3befc47c1a01dc91b22cbbc708a3"
+content-hash = "b3dc23f376de141d22b729d038144a1e6d66983a910160c3500fe0d79f8e5917"
pyproject.toml
@@ -100,6 +100,7 @@ azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
 py-trello = {version = "^0.19.0", optional = true}
 momento = {version = "^1.5.0", optional = true}
 bibtexparser = {version = "^1.4.0", optional = true}
+pyspark = {version = "^3.4.0", optional = true}
 
 [tool.poetry.group.docs.dependencies]
 autodoc_pydantic = "^1.8.0"
@@ -301,6 +302,7 @@ extended_testing = [
     "html2text",
     "py-trello",
     "scikit-learn",
+    "pyspark",
 ]
 
 [tool.ruff]
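Because pyspark ships as an optional extra, here is a small sketch (not part of this PR) of guarding the import before constructing the loader; it mirrors the try/except inside PySparkDataFrameLoader.__init__:

import importlib.util

# Check for the optional dependency up front rather than catching the
# ValueError raised by PySparkDataFrameLoader when pyspark is missing.
if importlib.util.find_spec("pyspark") is None:
    raise RuntimeError(
        "pyspark is not installed; install it with `pip install pyspark`"
    )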
New test file for PySparkDataFrameLoader (38 lines)
@@ -0,0 +1,38 @@
import random
import string

from langchain.docstore.document import Document
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader


def test_pyspark_loader_load_valid_data() -> None:
    from pyspark.sql import SparkSession

    # Requires a session to be set up
    spark = SparkSession.builder.getOrCreate()
    data = [
        (random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
    ]
    df = spark.createDataFrame(data, ["text", "label"])

    expected_docs = [
        Document(
            page_content=data[0][0],
            metadata={"label": data[0][1]},
        ),
        Document(
            page_content=data[1][0],
            metadata={"label": data[1][1]},
        ),
        Document(
            page_content=data[2][0],
            metadata={"label": data[2][1]},
        ),
    ]

    loader = PySparkDataFrameLoader(
        spark_session=spark, df=df, page_content_column="text"
    )
    result = loader.load()

    assert result == expected_docs