mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 15:04:13 +00:00
Harrison/spark reader (#5405)
Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
|
||||
|
||||
|
||||
def test_pyspark_loader_load_valid_data() -> None:
    """Check that the loader maps each DataFrame row to one Document.

    Three random (letter, 0/1 label) rows are put into a Spark DataFrame
    with columns ``text`` and ``label``. With ``page_content_column="text"``
    the test expects the ``text`` value as page content and the ``label``
    value as the only metadata entry, in the same order as the input rows.
    """
    from pyspark.sql import SparkSession

    # Requires a session to be set up
    session = SparkSession.builder.getOrCreate()

    # Draw the letter first, then the label, for each of the three rows.
    rows = []
    for _ in range(3):
        letter = random.choice(string.ascii_letters)
        label = random.randint(0, 1)
        rows.append((letter, label))
    frame = spark_frame = session.createDataFrame(rows, ["text", "label"])

    # One expected Document per input row, preserving row order.
    expected = [
        Document(page_content=text, metadata={"label": label})
        for text, label in rows
    ]

    loaded = PySparkDataFrameLoader(
        spark_session=session,
        df=frame,
        page_content_column="text",
    ).load()

    assert loaded == expected
|
Reference in New Issue
Block a user