Compare commits

...

3 Commits

Author SHA1 Message Date
Evgene Yurtsev
621c8b914d x 2024-01-30 17:31:55 -08:00
Christophe Bornet
d092f8a013 Add test of default methods 2024-01-27 13:04:07 +01:00
Christophe Bornet
a688185b83 Add async methods in BaseLoader 2024-01-27 12:53:26 +01:00
5 changed files with 69 additions and 56 deletions

View File

@@ -2,9 +2,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Iterator, List, Optional
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional
from langchain_core.documents import Document
from langchain_core.runnables import run_in_executor
from langchain_community.document_loaders.blob_loaders import Blob
@@ -52,14 +53,22 @@ class BaseLoader(ABC):
# Attention: This method will be upgraded into an abstractmethod once it's
# implemented in all the existing subclasses.
def lazy_load(
self,
) -> Iterator[Document]:
def lazy_load(self) -> Iterator[Document]:
"""A lazy loader for Documents."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement lazy_load()"
)
async def alazy_load(self) -> AsyncIterator[Document]:
    """Asynchronously yield Documents one at a time.

    Default implementation built on top of the synchronous ``lazy_load()``:
    the iterator is created, and each item is pulled from it, through
    ``run_in_executor`` (presumably so that a blocking ``lazy_load()`` does
    not stall the event loop -- confirm against ``langchain_core``).

    Yields:
        Each Document produced by ``lazy_load()``, in order.
    """
    # Unique marker returned by next() once the iterator is exhausted; using
    # a default avoids raising StopIteration inside the executor call.
    sentinel = object()
    documents = await run_in_executor(None, self.lazy_load)
    while True:
        item = await run_in_executor(None, next, documents, sentinel)
        if item is sentinel:
            return
        yield item
class BaseBlobParser(ABC):
"""Abstract interface for blob parsers.

View File

@@ -1,4 +1,4 @@
from typing import Iterator, List
from typing import AsyncIterator, Iterator, List
from langchain_core.documents import Document
@@ -26,3 +26,9 @@ class MergedDataLoader(BaseLoader):
def load(self) -> List[Document]:
"""Load docs."""
return list(self.lazy_load())
async def alazy_load(self) -> AsyncIterator[Document]:
    """Asynchronously yield documents from every wrapped loader in turn.

    Iterates ``self.loaders`` in order and re-yields everything each
    loader's own ``alazy_load()`` produces; a loader is fully drained
    before the next one is started.

    Yields:
        Each document from each loader, in loader order.
    """
    for source in self.loaders:
        stream = source.alazy_load()
        async for doc in stream:
            yield doc

View File

@@ -1,9 +1,9 @@
"""Test Base Schema of documents."""
from typing import Iterator
from typing import Iterator, List
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseBlobParser
from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
from langchain_community.document_loaders.blob_loaders import Blob
@@ -27,3 +27,25 @@ def test_base_blob_parser() -> None:
docs = parser.parse(Blob(data="who?"))
assert len(docs) == 1
assert docs[0].page_content == "foo"
async def test_default_aload() -> None:
    """Check that the inherited ``alazy_load()`` mirrors ``lazy_load()``.

    ``FakeLoader`` defines only the synchronous ``load()`` and
    ``lazy_load()``; the async path therefore exercises whatever
    ``alazy_load()`` it inherits from ``BaseLoader``.
    """

    class FakeLoader(BaseLoader):
        """Minimal loader that provides only the synchronous methods."""

        def load(self) -> List[Document]:
            return list(self.lazy_load())

        def lazy_load(self) -> Iterator[Document]:
            for text in ("foo", "bar"):
                yield Document(page_content=text)

    expected = [Document(page_content="foo"), Document(page_content="bar")]
    loader = FakeLoader()

    # Synchronous path.
    assert loader.load() == expected

    # Test that async lazy loading works
    collected = []
    async for doc in loader.alazy_load():
        collected.append(doc)
    assert collected == expected

View File

@@ -391,7 +391,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
async def aindex(
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]],
record_manager: RecordManager,
vector_store: VectorStore,
*,
@@ -469,16 +469,17 @@ async def aindex(
# implementation which just raises a NotImplementedError
raise ValueError("Vectorstore has not implemented the delete method")
if isinstance(docs_source, BaseLoader):
raise NotImplementedError(
"Not supported yet. Please pass an async iterator of documents."
)
async_doc_iterator: AsyncIterator[Document]
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
if isinstance(docs_source, BaseLoader):
try:
async_doc_iterator = docs_source.alazy_load()
except NotImplementedError:
async_doc_iterator = _to_async_iterator(await docs_source.aload())
else:
async_doc_iterator = _to_async_iterator(docs_source)
if hasattr(docs_source, "__aiter__"):
async_doc_iterator = docs_source # type: ignore[assignment]
else:
async_doc_iterator = _to_async_iterator(docs_source)
source_id_assigner = _get_source_id_assigner(source_id_key)

View File

@@ -40,19 +40,6 @@ class ToyLoader(BaseLoader):
"""Load the documents from the source."""
return list(self.lazy_load())
async def alazy_load(
self,
) -> AsyncIterator[Document]:
async def async_generator() -> AsyncIterator[Document]:
for document in self.documents:
yield document
return async_generator()
async def aload(self) -> List[Document]:
"""Load the documents from the source."""
return [doc async for doc in await self.alazy_load()]
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary."""
@@ -232,7 +219,7 @@ async def test_aindexing_same_content(
]
)
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
assert await aindex(loader, arecord_manager, vector_store) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
@@ -243,9 +230,7 @@ async def test_aindexing_same_content(
for _ in range(2):
# Run the indexing again
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store
) == {
assert await aindex(loader, arecord_manager, vector_store) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -347,9 +332,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
@@ -359,9 +342,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -382,9 +363,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 1,
"num_deleted": 1,
"num_skipped": 1,
@@ -402,9 +381,7 @@ async def test_aindex_simple_delete_full(
with patch.object(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
@@ -473,7 +450,7 @@ async def test_aincremental_fails_with_bad_source_ids(
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup="incremental",
@@ -482,7 +459,7 @@ async def test_aincremental_fails_with_bad_source_ids(
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup="incremental",
@@ -593,7 +570,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup=None,
@@ -610,7 +587,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup=None,
@@ -640,7 +617,7 @@ async def test_ano_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader,
arecord_manager,
vector_store,
cleanup=None,
@@ -779,7 +756,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader.lazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -803,7 +780,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader.lazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -838,7 +815,7 @@ async def test_aincremental_delete(
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
):
assert await aindex(
await loader.alazy_load(),
loader.lazy_load(),
arecord_manager,
vector_store,
cleanup="incremental",
@@ -883,9 +860,7 @@ async def test_aindexing_with_no_docs(
"""Check edge case when loader returns no new docs."""
loader = ToyLoader(documents=[])
assert await aindex(
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
) == {
assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 0,