mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
langchain[minor],community[minor]: Add async methods in BaseLoader (#16634)
Adds:
* methods `aload()` and `alazy_load()` to interface `BaseLoader`
* implementation for class `MergedDataLoader`
* support for class `BaseLoader` in async function `aindex()` with unit tests

Note: this is compatible with existing `aload()` methods that some loaders already had.

**Twitter handle:** @cbornet_

---------

Co-authored-by: Eugene Yurtsev <eugene@langchain.dev>
This commit is contained in:
parent
c37ca45825
commit
af8c5c185b
@ -2,9 +2,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Iterator, List, Optional
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
from langchain_community.document_loaders.blob_loaders import Blob
|
||||
|
||||
@ -52,14 +53,22 @@ class BaseLoader(ABC):
|
||||
|
||||
# Attention: This method will be upgraded into an abstractmethod once it's
|
||||
# implemented in all the existing subclasses.
|
||||
def lazy_load(
|
||||
self,
|
||||
) -> Iterator[Document]:
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""A lazy loader for Documents."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement lazy_load()"
|
||||
)
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
    """Asynchronously yield the Documents produced by ``lazy_load``.

    The synchronous iterator is obtained by calling ``self.lazy_load`` via
    ``run_in_executor``, and every subsequent item is fetched with ``next``
    through ``run_in_executor`` as well, one item per call.

    Yields:
        Each Document from ``lazy_load``, in the same order.
    """
    sync_iter = await run_in_executor(None, self.lazy_load)
    # A unique object used as the default for ``next``: getting it back means
    # the iterator is exhausted, so no StopIteration has to be raised.
    exhausted = object()
    item = await run_in_executor(None, next, sync_iter, exhausted)
    while item is not exhausted:
        yield item
        item = await run_in_executor(None, next, sync_iter, exhausted)
|
||||
|
||||
|
||||
class BaseBlobParser(ABC):
|
||||
"""Abstract interface for blob parsers.
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Iterator, List
|
||||
from typing import AsyncIterator, Iterator, List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@ -26,3 +26,9 @@ class MergedDataLoader(BaseLoader):
|
||||
def load(self) -> List[Document]:
|
||||
"""Load docs."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
    """Asynchronously yield documents from every wrapped loader in turn.

    Loaders in ``self.loaders`` are consumed sequentially: all documents of
    one loader's ``alazy_load()`` are yielded before the next loader is
    started.
    """
    for source in self.loaders:
        stream = source.alazy_load()
        async for doc in stream:
            yield doc
|
||||
|
@ -1,9 +1,9 @@
|
||||
"""Test Base Schema of documents."""
|
||||
from typing import Iterator
|
||||
from typing import Iterator, List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseBlobParser
|
||||
from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
|
||||
from langchain_community.document_loaders.blob_loaders import Blob
|
||||
|
||||
|
||||
@ -27,3 +27,20 @@ def test_base_blob_parser() -> None:
|
||||
docs = parser.parse(Blob(data="who?"))
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "foo"
|
||||
|
||||
|
||||
async def test_default_aload() -> None:
    """Check that the inherited ``alazy_load`` yields the same docs as ``load``."""
    expected = [Document(page_content="foo"), Document(page_content="bar")]

    class FakeLoader(BaseLoader):
        """Loader that only provides the synchronous entry points."""

        def lazy_load(self) -> Iterator[Document]:
            for text in ("foo", "bar"):
                yield Document(page_content=text)

        def load(self) -> List[Document]:
            return list(self.lazy_load())

    fake = FakeLoader()
    sync_docs = fake.load()
    assert sync_docs == expected

    # alazy_load is not overridden by FakeLoader, so this exercises the
    # implementation inherited from BaseLoader.
    async_docs = [doc async for doc in fake.alazy_load()]
    assert sync_docs == async_docs
|
||||
|
@ -391,7 +391,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
|
||||
|
||||
|
||||
async def aindex(
|
||||
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
|
||||
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
|
||||
record_manager: RecordManager,
|
||||
vector_store: VectorStore,
|
||||
*,
|
||||
@ -469,16 +469,22 @@ async def aindex(
|
||||
# implementation which just raises a NotImplementedError
|
||||
raise ValueError("Vectorstore has not implemented the delete method")
|
||||
|
||||
if isinstance(docs_source, BaseLoader):
|
||||
raise NotImplementedError(
|
||||
"Not supported yet. Please pass an async iterator of documents."
|
||||
)
|
||||
async_doc_iterator: AsyncIterator[Document]
|
||||
|
||||
if hasattr(docs_source, "__aiter__"):
|
||||
async_doc_iterator = docs_source # type: ignore[assignment]
|
||||
if isinstance(docs_source, BaseLoader):
|
||||
try:
|
||||
async_doc_iterator = docs_source.alazy_load()
|
||||
except NotImplementedError:
|
||||
# Exception triggered when neither lazy_load nor alazy_load is implemented.
|
||||
# * The default implementation of alazy_load uses lazy_load.
|
||||
# * The default implementation of lazy_load raises NotImplementedError.
|
||||
# In such a case, we use the load method and convert it to an async
|
||||
# iterator.
|
||||
async_doc_iterator = _to_async_iterator(docs_source.load())
|
||||
else:
|
||||
async_doc_iterator = _to_async_iterator(docs_source)
|
||||
if hasattr(docs_source, "__aiter__"):
|
||||
async_doc_iterator = docs_source # type: ignore[assignment]
|
||||
else:
|
||||
async_doc_iterator = _to_async_iterator(docs_source)
|
||||
|
||||
source_id_assigner = _get_source_id_assigner(source_id_key)
|
||||
|
||||
|
@ -43,15 +43,8 @@ class ToyLoader(BaseLoader):
|
||||
async def alazy_load(
|
||||
self,
|
||||
) -> AsyncIterator[Document]:
|
||||
async def async_generator() -> AsyncIterator[Document]:
|
||||
for document in self.documents:
|
||||
yield document
|
||||
|
||||
return async_generator()
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
"""Load the documents from the source."""
|
||||
return [doc async for doc in await self.alazy_load()]
|
||||
for document in self.documents:
|
||||
yield document
|
||||
|
||||
|
||||
class InMemoryVectorStore(VectorStore):
|
||||
@ -232,7 +225,7 @@ async def test_aindexing_same_content(
|
||||
]
|
||||
)
|
||||
|
||||
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
@ -243,9 +236,7 @@ async def test_aindexing_same_content(
|
||||
|
||||
for _ in range(2):
|
||||
# Run the indexing again
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store
|
||||
) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
@ -347,9 +338,7 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
@ -359,9 +348,7 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
@ -382,9 +369,7 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 1,
|
||||
"num_deleted": 1,
|
||||
"num_skipped": 1,
|
||||
@ -402,9 +387,7 @@ async def test_aindex_simple_delete_full(
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
@ -473,7 +456,7 @@ async def test_aincremental_fails_with_bad_source_ids(
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader,
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@ -482,7 +465,7 @@ async def test_aincremental_fails_with_bad_source_ids(
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader,
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@ -593,7 +576,7 @@ async def test_ano_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader,
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
@ -610,7 +593,7 @@ async def test_ano_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader,
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
@ -640,7 +623,7 @@ async def test_ano_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader,
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
@ -779,7 +762,7 @@ async def test_aincremental_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader.lazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@ -803,7 +786,7 @@ async def test_aincremental_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader.lazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@ -838,7 +821,7 @@ async def test_aincremental_delete(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
loader.lazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
@ -883,9 +866,7 @@ async def test_aindexing_with_no_docs(
|
||||
"""Check edge case when loader returns no new docs."""
|
||||
loader = ToyLoader(documents=[])
|
||||
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
|
Loading…
Reference in New Issue
Block a user