Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-05 12:48:12 +00:00)
Harrison/spark reader (#5405)

Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>

parent 8259f9b7fa
commit 760632b292
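
This PR adds a PySparkDataFrameLoader that turns each row of a PySpark DataFrame into a Document, together with an example notebook, an integration test, and an optional pyspark dependency in pyproject.toml and poetry.lock. In short, assuming an existing SparkSession spark and a DataFrame df with a "text" column (a sketch of the new API, mirroring the test added below):

from langchain.document_loaders import PySparkDataFrameLoader

loader = PySparkDataFrameLoader(spark, df, page_content_column="text")
docs = loader.load()  # one Document per row; the remaining columns become metadata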
@@ -130,6 +130,7 @@ We need access tokens and sometime other parameters to get access to these datas
./document_loaders/examples/notion.ipynb
./document_loaders/examples/obsidian.ipynb
./document_loaders/examples/psychic.ipynb
./document_loaders/examples/pyspark_dataframe.ipynb
./document_loaders/examples/readthedocs_documentation.ipynb
./document_loaders/examples/reddit.ipynb
./document_loaders/examples/roam.ipynb
@@ -0,0 +1,97 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PySpark DataFrame Loader\n",
    "\n",
    "This notebook shows how to load data from a PySpark DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install pyspark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark = SparkSession.builder.getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import PySparkDataFrameLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader.load()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
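
Besides load(), the loader added in this PR also exposes lazy_load(), which yields one Document per DataFrame row instead of materializing the whole list. A minimal sketch, reusing the spark session, df, and "Team" column from the notebook above:

from langchain.document_loaders import PySparkDataFrameLoader

loader = PySparkDataFrameLoader(spark, df, page_content_column="Team")
# Each yielded Document uses the "Team" value as page_content; the other columns become metadata.
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)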
@@ -74,6 +74,7 @@ from langchain.document_loaders.pdf import (
)
from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
from langchain.document_loaders.psychic import PsychicLoader
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
from langchain.document_loaders.python import PythonLoader
from langchain.document_loaders.readthedocs import ReadTheDocsLoader
from langchain.document_loaders.reddit import RedditPostsLoader
@@ -188,6 +189,7 @@ __all__ = [
    "PyPDFDirectoryLoader",
    "PyPDFLoader",
    "PyPDFium2Loader",
    "PySparkDataFrameLoader",
    "PythonLoader",
    "ReadTheDocsLoader",
    "RedditPostsLoader",
langchain/document_loaders/pyspark_dataframe.py (new file, 80 lines)
@@ -0,0 +1,80 @@
"""Load from a Spark DataFrame object."""
import itertools
import logging
import sys
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple

import psutil

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader

logger = logging.getLogger(__file__)

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


class PySparkDataFrameLoader(BaseLoader):
    """Load PySpark DataFrames."""

    def __init__(
        self,
        spark_session: Optional["SparkSession"] = None,
        df: Optional[Any] = None,
        page_content_column: str = "text",
        fraction_of_memory: float = 0.1,
    ):
        """Initialize with a Spark DataFrame object."""
        try:
            from pyspark.sql import DataFrame, SparkSession
        except ImportError:
            raise ValueError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            )

        self.spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )

        if not isinstance(df, DataFrame):
            raise ValueError(
                f"Expected df to be a PySpark DataFrame, got {type(df)}"
            )
        self.df = df
        self.page_content_column = page_content_column
        self.fraction_of_memory = fraction_of_memory
        self.num_rows, self.max_num_rows = self.get_num_rows()
        self.rdd_df = self.df.rdd.map(list)
        self.column_names = self.df.columns

    def get_num_rows(self) -> Tuple[int, int]:
        """Get the number of "feasible" rows for the DataFrame."""
        row = self.df.limit(1).collect()[0]
        estimated_row_size = sys.getsizeof(row)
        mem_info = psutil.virtual_memory()
        available_memory = mem_info.available
        max_num_rows = int(
            (available_memory / estimated_row_size) * self.fraction_of_memory
        )
        return min(max_num_rows, self.df.count()), max_num_rows

    def lazy_load(self) -> Iterator[Document]:
        """A lazy loader for document content."""
        for row in self.rdd_df.toLocalIterator():
            metadata = {self.column_names[i]: row[i] for i in range(len(row))}
            text = metadata[self.page_content_column]
            metadata.pop(self.page_content_column)
            yield Document(page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        """Load from the dataframe."""
        if self.df.count() > self.max_num_rows:
            logger.warning(
                f"The number of DataFrame rows is {self.df.count()}, "
                f"but we will only include the amount "
                f"of rows that can reasonably fit in memory: {self.num_rows}."
            )
        lazy_load_iterator = self.lazy_load()
        return list(itertools.islice(lazy_load_iterator, self.num_rows))
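
The row cap used by load() comes from get_num_rows(): it sizes one collected Row with sys.getsizeof, reads the driver's available memory from psutil, and keeps only fraction_of_memory of the rows that would fit. A quick worked example of that arithmetic with assumed numbers (illustrative only, not part of the commit):

estimated_row_size = 1_000      # bytes, as if returned by sys.getsizeof(row)
available_memory = 8 * 1024**3  # 8 GiB, as if reported by psutil.virtual_memory().available
fraction_of_memory = 0.1        # the loader's default
max_num_rows = int((available_memory / estimated_row_size) * fraction_of_memory)
print(max_num_rows)             # 858993 -> load() would return at most this many Documents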
poetry.lock (generated, 37 lines changed)
@@ -6643,6 +6643,18 @@ pytz = "*"
requests = "*"
requests-oauthlib = ">=0.4.1"

[[package]]
name = "py4j"
version = "0.10.9.7"
description = "Enables Python programs to dynamically access arbitrary Java objects"
category = "main"
optional = true
python-versions = "*"
files = [
    {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"},
    {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"},
]

[[package]]
name = "pyaes"
version = "1.6.1"
@@ -7229,6 +7241,27 @@ files = [
    {file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"},
]

[[package]]
name = "pyspark"
version = "3.4.0"
description = "Apache Spark Python API"
category = "main"
optional = true
python-versions = ">=3.7"
files = [
    {file = "pyspark-3.4.0.tar.gz", hash = "sha256:167a23e11854adb37f8602de6fcc3a4f96fd5f1e323b9bb83325f38408c5aafd"},
]

[package.dependencies]
py4j = "0.10.9.7"

[package.extras]
connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
ml = ["numpy (>=1.15)"]
mllib = ["numpy (>=1.15)"]
pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]

[[package]]
name = "pytesseract"
version = "0.3.10"
@@ -10920,7 +10953,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
-extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
+extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
@@ -10929,4 +10962,4 @@ text-helpers = ["chardet"]

[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "1033e47cdab7d3a15fb9322bad64609f77fd3befc47c1a01dc91b22cbbc708a3"
+content-hash = "b3dc23f376de141d22b729d038144a1e6d66983a910160c3500fe0d79f8e5917"
@@ -100,6 +100,7 @@ azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
py-trello = {version = "^0.19.0", optional = true}
momento = {version = "^1.5.0", optional = true}
bibtexparser = {version = "^1.4.0", optional = true}
pyspark = {version = "^3.4.0", optional = true}

[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@@ -301,6 +302,7 @@ extended_testing = [
    "html2text",
    "py-trello",
    "scikit-learn",
    "pyspark",
]

[tool.ruff]
@@ -0,0 +1,38 @@
import random
import string

from langchain.docstore.document import Document
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader


def test_pyspark_loader_load_valid_data() -> None:
    from pyspark.sql import SparkSession

    # Requires a session to be set up
    spark = SparkSession.builder.getOrCreate()
    data = [
        (random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
    ]
    df = spark.createDataFrame(data, ["text", "label"])

    expected_docs = [
        Document(
            page_content=data[0][0],
            metadata={"label": data[0][1]},
        ),
        Document(
            page_content=data[1][0],
            metadata={"label": data[1][1]},
        ),
        Document(
            page_content=data[2][0],
            metadata={"label": data[2][1]},
        ),
    ]

    loader = PySparkDataFrameLoader(
        spark_session=spark, df=df, page_content_column="text"
    )
    result = loader.load()

    assert result == expected_docs