mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
community[minor]: Add async methods to CassandraLoader (#20609)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
8c29b7bf35
commit
d2d01370bc
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Iterator,
|
||||
Optional,
|
||||
@ -13,6 +14,7 @@ from typing import (
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_community.utilities.cassandra import wrapped_response_future
|
||||
|
||||
_NOT_SET = object()
|
||||
|
||||
@ -112,3 +114,15 @@ class CassandraLoader(BaseLoader):
|
||||
yield Document(
|
||||
page_content=self.page_content_mapper(row), metadata=metadata
|
||||
)
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
    """Asynchronously yield one ``Document`` per row returned by the query.

    The query is started through ``self.session.execute_async`` and awaited
    via ``wrapped_response_future``; the awaited result is then iterated
    row by row.

    Yields:
        A ``Document`` whose content comes from ``self.page_content_mapper``
        and whose metadata is a copy of ``self.metadata`` updated with the
        output of ``self.metadata_mapper`` for that row.
    """
    result_rows = await wrapped_response_future(
        self.session.execute_async,
        self.query,
        **self.query_kwargs,
    )
    for record in result_rows:
        # Copy first so per-row values never leak into the shared base dict.
        doc_metadata = self.metadata.copy()
        doc_metadata.update(self.metadata_mapper(record))
        content = self.page_content_mapper(record)
        yield Document(page_content=content, metadata=doc_metadata)
|
||||
|
24
libs/community/langchain_community/utilities/cassandra.py
Normal file
24
libs/community/langchain_community/utilities/cassandra.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cassandra.cluster import ResponseFuture
|
||||
|
||||
|
||||
async def wrapped_response_future(
    func: Callable[..., "ResponseFuture"], *args: Any, **kwargs: Any
) -> Any:
    """Await a callback-based response future from asyncio code.

    Calls ``func(*args, **kwargs)`` to start the request, then bridges the
    success/error callbacks of the returned response future onto an asyncio
    future bound to the running event loop.

    Args:
        func: Callable that starts the request and returns an object exposing
            ``add_callbacks(callback, errback)`` and ``result()``.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value returned by ``response_future.result()``.

    Raises:
        BaseException: Whatever exception is passed to the error callback,
            or raised by ``response_future.result()``.
    """
    # Inside a coroutine there is always a running loop; ask for it directly
    # instead of relying on get_event_loop()'s policy-dependent fallback.
    loop = asyncio.get_running_loop()
    asyncio_future = loop.create_future()
    response_future = func(*args, **kwargs)

    def settle(setter: Callable[[Any], None], value: Any) -> None:
        # Runs on the loop thread. If the awaiting task was cancelled in the
        # meantime the future is already done and setting it would raise
        # InvalidStateError, so late completions are dropped.
        if not asyncio_future.done():
            setter(value)

    def success_handler(_: Any) -> None:
        # The callback argument is ignored; the value handed to the awaiter
        # is response_future.result(), as before.
        try:
            result = response_future.result()
        except Exception as exc:
            # Forward the failure so the awaiter does not hang forever.
            loop.call_soon_threadsafe(settle, asyncio_future.set_exception, exc)
        else:
            loop.call_soon_threadsafe(settle, asyncio_future.set_result, result)

    def error_handler(exc: BaseException) -> None:
        loop.call_soon_threadsafe(settle, asyncio_future.set_exception, exc)

    # Callbacks may fire on a thread other than the loop's, hence the
    # call_soon_threadsafe hand-off above.
    response_future.add_callbacks(success_handler, error_handler)
    return await asyncio_future
|
@ -55,9 +55,9 @@ def keyspace() -> Iterator[str]:
|
||||
session.execute(f"DROP TABLE IF EXISTS {keyspace}.{CASSANDRA_TABLE}")
|
||||
|
||||
|
||||
def test_loader_table(keyspace: str) -> None:
|
||||
async def test_loader_table(keyspace: str) -> None:
|
||||
loader = CassandraLoader(table=CASSANDRA_TABLE)
|
||||
assert loader.load() == [
|
||||
expected = [
|
||||
Document(
|
||||
page_content="Row(row_id='id1', body_blob='text1')",
|
||||
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
|
||||
@ -67,24 +67,28 @@ def test_loader_table(keyspace: str) -> None:
|
||||
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
|
||||
),
|
||||
]
|
||||
assert loader.load() == expected
|
||||
assert await loader.aload() == expected
|
||||
|
||||
|
||||
def test_loader_query(keyspace: str) -> None:
|
||||
async def test_loader_query(keyspace: str) -> None:
|
||||
loader = CassandraLoader(
|
||||
query=f"SELECT body_blob FROM {keyspace}.{CASSANDRA_TABLE}"
|
||||
)
|
||||
assert loader.load() == [
|
||||
expected = [
|
||||
Document(page_content="Row(body_blob='text1')"),
|
||||
Document(page_content="Row(body_blob='text2')"),
|
||||
]
|
||||
assert loader.load() == expected
|
||||
assert await loader.aload() == expected
|
||||
|
||||
|
||||
def test_loader_page_content_mapper(keyspace: str) -> None:
|
||||
async def test_loader_page_content_mapper(keyspace: str) -> None:
|
||||
def mapper(row: Any) -> str:
|
||||
return str(row.body_blob)
|
||||
|
||||
loader = CassandraLoader(table=CASSANDRA_TABLE, page_content_mapper=mapper)
|
||||
assert loader.load() == [
|
||||
expected = [
|
||||
Document(
|
||||
page_content="text1",
|
||||
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
|
||||
@ -94,14 +98,16 @@ def test_loader_page_content_mapper(keyspace: str) -> None:
|
||||
metadata={"table": CASSANDRA_TABLE, "keyspace": keyspace},
|
||||
),
|
||||
]
|
||||
assert loader.load() == expected
|
||||
assert await loader.aload() == expected
|
||||
|
||||
|
||||
def test_loader_metadata_mapper(keyspace: str) -> None:
|
||||
async def test_loader_metadata_mapper(keyspace: str) -> None:
|
||||
def mapper(row: Any) -> dict:
|
||||
return {"id": row.row_id}
|
||||
|
||||
loader = CassandraLoader(table=CASSANDRA_TABLE, metadata_mapper=mapper)
|
||||
assert loader.load() == [
|
||||
expected = [
|
||||
Document(
|
||||
page_content="Row(row_id='id1', body_blob='text1')",
|
||||
metadata={
|
||||
@ -119,3 +125,5 @@ def test_loader_metadata_mapper(keyspace: str) -> None:
|
||||
},
|
||||
),
|
||||
]
|
||||
assert loader.load() == expected
|
||||
assert await loader.aload() == expected
|
||||
|
Loading…
Reference in New Issue
Block a user