langchain[minor],community[minor]: Add async methods in BaseLoader (#16634)

Adds:
* methods `aload()` and `alazy_load()` to interface `BaseLoader`
* implementation for class `MergedDataLoader`
* support for class `BaseLoader` in async function `aindex()` with unit
tests

Note: this is compatible with existing `aload()` methods that some
loaders already had.

**Twitter handle:** @cbornet_

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
This commit is contained in:
Christophe Bornet
2024-01-31 20:08:11 +01:00
committed by GitHub
parent c37ca45825
commit af8c5c185b
5 changed files with 71 additions and 52 deletions

View File

@@ -391,7 +391,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
async def aindex(
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
record_manager: RecordManager,
vector_store: VectorStore,
*,
@@ -469,16 +469,22 @@ async def aindex(
# implementation which just raises a NotImplementedError
raise ValueError("Vectorstore has not implemented the delete method")
if isinstance(docs_source, BaseLoader):
raise NotImplementedError(
"Not supported yet. Please pass an async iterator of documents."
)
async_doc_iterator: AsyncIterator[Document]
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
if isinstance(docs_source, BaseLoader):
try:
async_doc_iterator = docs_source.alazy_load()
except NotImplementedError:
# Exception triggered when neither lazy_load nor alazy_load is implemented.
# * The default implementation of alazy_load uses lazy_load.
# * The default implementation of lazy_load raises NotImplementedError.
# In such a case, we use the load method and convert it to an async
# iterator.
async_doc_iterator = _to_async_iterator(docs_source.load())
else:
async_doc_iterator = _to_async_iterator(docs_source)
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
else:
async_doc_iterator = _to_async_iterator(docs_source)
source_id_assigner = _get_source_id_assigner(source_id_key)

View File

@@ -43,15 +43,8 @@ class ToyLoader(BaseLoader):
async def alazy_load(
self,
) -> AsyncIterator[Document]:
async def async_generator() -> AsyncIterator[Document]:
for document in self.documents:
yield document
return async_generator()
async def aload(self) -> List[Document]:
"""Load the documents from the source."""
return [doc async for doc in await self.alazy_load()]
for document in self.documents:
yield document
class InMemoryVectorStore(VectorStore):
@@ -232,7 +225,7 @@ async def test_aindexing_same_content(
]
)
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
assert await aindex(loader, arecord_manager, vector_store) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
@@ -243,9 +236,7 @@ async def test_aindexing_same_content(
for _ in range(2):
# Run the indexing again
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store
) == {
assert await aindex(loader, arecord_manager, vector_store) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -347,9 +338,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
@@ -359,9 +348,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -382,9 +369,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 1,
"num_deleted": 1,
"num_skipped": 1,
@@ -402,9 +387,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -473,7 +456,7 @@ async def test_aincremental_fails_with_bad_source_ids(
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup="incremental",
@@ -482,7 +465,7 @@ async def test_aincremental_fails_with_bad_source_ids(
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup="incremental",
@@ -593,7 +576,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup=None,
@@ -610,7 +593,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup=None,
@@ -640,7 +623,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup=None,
@@ -779,7 +762,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader.lazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -803,7 +786,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader.lazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -838,7 +821,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader.lazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -883,9 +866,7 @@ async def test_aindexing_with_no_docs(
"""Check edge case when loader returns no new docs."""
loader = ToyLoader(documents=[])
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 0,