[Fix] Fix Cassandra Document loader default page content mapper (#16273)

We can't use `json.dumps` by default as many types returned by the
cassandra driver are not serializable. It's safer to use `str` and let
users define their own custom `page_content_mapper` if needed.
This commit is contained in:
Christophe Bornet
2024-01-27 20:23:02 +01:00
committed by GitHub
parent e86fd946c8
commit 4915c3cd86
4 changed files with 11 additions and 15 deletions

View File

@@ -59,6 +59,7 @@ from langchain_community.document_loaders.blob_loaders import (
from langchain_community.document_loaders.blockchain import BlockchainDocumentLoader from langchain_community.document_loaders.blockchain import BlockchainDocumentLoader
from langchain_community.document_loaders.brave_search import BraveSearchLoader from langchain_community.document_loaders.brave_search import BraveSearchLoader
from langchain_community.document_loaders.browserless import BrowserlessLoader from langchain_community.document_loaders.browserless import BrowserlessLoader
from langchain_community.document_loaders.cassandra import CassandraLoader
from langchain_community.document_loaders.chatgpt import ChatGPTLoader from langchain_community.document_loaders.chatgpt import ChatGPTLoader
from langchain_community.document_loaders.chromium import AsyncChromiumLoader from langchain_community.document_loaders.chromium import AsyncChromiumLoader
from langchain_community.document_loaders.college_confidential import ( from langchain_community.document_loaders.college_confidential import (
@@ -267,6 +268,7 @@ __all__ = [
"BlockchainDocumentLoader", "BlockchainDocumentLoader",
"BraveSearchLoader", "BraveSearchLoader",
"BrowserlessLoader", "BrowserlessLoader",
"CassandraLoader",
"CSVLoader", "CSVLoader",
"ChatGPTLoader", "ChatGPTLoader",
"CoNLLULoader", "CoNLLULoader",

View File

@@ -1,4 +1,3 @@
import json
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@@ -14,13 +13,6 @@ from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader from langchain_community.document_loaders.base import BaseLoader
def default_page_content_mapper(row: Any) -> str:
if hasattr(row, "_asdict"):
return json.dumps(row._asdict())
return json.dumps(row)
_NOT_SET = object() _NOT_SET = object()
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -36,7 +28,7 @@ class CassandraLoader(BaseLoader):
session: Optional["Session"] = None, session: Optional["Session"] = None,
keyspace: Optional[str] = None, keyspace: Optional[str] = None,
query: Optional[Union[str, "Statement"]] = None, query: Optional[Union[str, "Statement"]] = None,
page_content_mapper: Callable[[Any], str] = default_page_content_mapper, page_content_mapper: Callable[[Any], str] = str,
metadata_mapper: Callable[[Any], dict] = lambda _: {}, metadata_mapper: Callable[[Any], dict] = lambda _: {},
*, *,
query_parameters: Union[dict, Sequence] = None, query_parameters: Union[dict, Sequence] = None,
@@ -61,6 +53,7 @@ class CassandraLoader(BaseLoader):
query: The query used to load the data. query: The query used to load the data.
(do not use together with the table parameter) (do not use together with the table parameter)
page_content_mapper: a function to convert a row to string page content. page_content_mapper: a function to convert a row to string page content.
Defaults to the str representation of the row.
query_parameters: The query parameters used when calling session.execute . query_parameters: The query parameters used when calling session.execute .
query_timeout: The query timeout used when calling session.execute . query_timeout: The query timeout used when calling session.execute .
query_custom_payload: The query custom_payload used when calling query_custom_payload: The query custom_payload used when calling

View File

@@ -59,11 +59,11 @@ def test_loader_table(keyspace: str) -> None:
loader = CassandraLoader(table=CASSANDRA_TABLE) loader = CassandraLoader(table=CASSANDRA_TABLE)
assert loader.load() == [ assert loader.load() == [
Document( Document(
page_content='{"row_id": "id1", "body_blob": "text1"}', page_content="Row(row_id='id1', body_blob='text1')",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
), ),
Document( Document(
page_content='{"row_id": "id2", "body_blob": "text2"}', page_content="Row(row_id='id2', body_blob='text2')",
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace}, metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
), ),
] ]
@@ -74,8 +74,8 @@ def test_loader_query(keyspace: str) -> None:
query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}" query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}"
) )
assert loader.load() == [ assert loader.load() == [
Document(page_content='{"body_blob": "text1"}'), Document(page_content="Row(body_blob='text1')"),
Document(page_content='{"body_blob": "text2"}'), Document(page_content="Row(body_blob='text2')"),
] ]
@@ -103,7 +103,7 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper) loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper)
assert loader.load() == [ assert loader.load() == [
Document( Document(
page_content='{"row_id": "id1", "body_blob": "text1"}', page_content="Row(row_id='id1', body_blob='text1')",
metadata={ metadata={
"table": CASSANDRA_TABLE, "table": CASSANDRA_TABLE,
"keyspace": keyspace, "keyspace": keyspace,
@@ -111,7 +111,7 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
}, },
), ),
Document( Document(
page_content='{"row_id": "id2", "body_blob": "text2"}', page_content="Row(row_id='id2', body_blob='text2')",
metadata={ metadata={
"table": CASSANDRA_TABLE, "table": CASSANDRA_TABLE,
"keyspace": keyspace, "keyspace": keyspace,

View File

@@ -37,6 +37,7 @@ EXPECTED_ALL = [
"BlockchainDocumentLoader", "BlockchainDocumentLoader",
"BraveSearchLoader", "BraveSearchLoader",
"BrowserlessLoader", "BrowserlessLoader",
"CassandraLoader",
"CSVLoader", "CSVLoader",
"ChatGPTLoader", "ChatGPTLoader",
"CoNLLULoader", "CoNLLULoader",