Harrison/spark reader (#5405)

Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
Harrison Chase
2023-05-29 20:23:17 -07:00
committed by GitHub
parent 8259f9b7fa
commit 760632b292
7 changed files with 255 additions and 2 deletions

View File

@@ -0,0 +1,38 @@
import random
import string
from langchain.docstore.document import Document
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
def test_pyspark_loader_load_valid_data() -> None:
from pyspark.sql import SparkSession
# Requires a session to be set up
spark = SparkSession.builder.getOrCreate()
data = [
(random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
]
df = spark.createDataFrame(data, ["text", "label"])
expected_docs = [
Document(
page_content=data[0][0],
metadata={"label": data[0][1]},
),
Document(
page_content=data[1][0],
metadata={"label": data[1][1]},
),
Document(
page_content=data[2][0],
metadata={"label": data[2][1]},
),
]
loader = PySparkDataFrameLoader(
spark_session=spark, df=df, page_content_column="text"
)
result = loader.load()
assert result == expected_docs