core[minor]: Support asynchronous operation in InMemoryVectorStore (#24472)

### Description

* Support asynchronous operation in InMemoryVectorStore.
* Because embeddings may be computed asynchronously, ensure that
both the asynchronous and the synchronous functions operate correctly.
This commit is contained in:
남광우 2024-07-26 00:36:55 +09:00 committed by GitHub
parent 5fdbdd6bec
commit 256bad3251
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 165 additions and 22 deletions

View File

@ -8,7 +8,6 @@ from typing import (
Any, Any,
Callable, Callable,
Dict, Dict,
Iterable,
List, List,
Optional, Optional,
Sequence, Sequence,
@ -74,6 +73,27 @@ class InMemoryVectorStore(VectorStore):
"failed": [], "failed": [],
} }
async def aupsert(
    self, items: Sequence[Document], /, **kwargs: Any
) -> UpsertResponse:
    """Asynchronously add or overwrite documents in the in-memory store.

    The page contents of all documents are embedded with a single awaited
    call to the embedding model's ``aembed_documents``. One record is then
    written to ``self.store`` per document, keyed by the document id, so a
    record already stored under the same id is overwritten.

    Args:
        items: Documents to store. A document whose ``id`` is falsy is
            assigned a freshly generated UUID4 string.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A mapping whose ``"succeeded"`` entry lists the ids that were
        written, in input order, and whose ``"failed"`` entry is always
        an empty list.
    """
    contents = [document.page_content for document in items]
    embedded = await self.embedding.aembed_documents(contents)

    stored_ids = []
    for document, embedding_vector in zip(items, embedded):
        # Keep a caller-supplied id; otherwise mint a random one.
        key = document.id or str(uuid.uuid4())
        stored_ids.append(key)
        self.store[key] = {
            "id": key,
            "vector": embedding_vector,
            "text": document.page_content,
            "metadata": document.metadata,
        }

    return {"succeeded": stored_ids, "failed": []}
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]: def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
"""Get documents by their ids. """Get documents by their ids.
@ -108,14 +128,6 @@ class InMemoryVectorStore(VectorStore):
""" """
return self.get_by_ids(ids) return self.get_by_ids(ids)
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
return self.add_texts(texts, metadatas, **kwargs)
def _similarity_search_with_score_by_vector( def _similarity_search_with_score_by_vector(
self, self,
embedding: List[float], embedding: List[float],
@ -172,7 +184,13 @@ class InMemoryVectorStore(VectorStore):
async def asimilarity_search_with_score( async def asimilarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
return self.similarity_search_with_score(query, k, **kwargs) embedding = await self.embedding.aembed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
k,
**kwargs,
)
return docs
def similarity_search_by_vector( def similarity_search_by_vector(
self, self,
@ -200,7 +218,10 @@ class InMemoryVectorStore(VectorStore):
async def asimilarity_search( async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
return self.similarity_search(query, k, **kwargs) return [
doc
for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
]
def max_marginal_relevance_search_by_vector( def max_marginal_relevance_search_by_vector(
self, self,
@ -249,6 +270,23 @@ class InMemoryVectorStore(VectorStore):
**kwargs, **kwargs,
) )
async def amax_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Run a max-marginal-relevance search for a text query, asynchronously.

    Only the query embedding is awaited (via the embedding model's
    ``aembed_query``). The search itself is delegated to the synchronous
    ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to embed and search with.
        k: Forwarded unchanged to the by-vector search.
        fetch_k: Forwarded unchanged to the by-vector search.
        lambda_mult: Forwarded unchanged (as a keyword) to the by-vector
            search.
        **kwargs: Extra keyword arguments forwarded to the by-vector search.

    Returns:
        Whatever ``max_marginal_relevance_search_by_vector`` returns for the
        embedded query.
    """
    query_vector = await self.embedding.aembed_query(query)
    mmr_documents = self.max_marginal_relevance_search_by_vector(
        query_vector,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        **kwargs,
    )
    return mmr_documents
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls,
@ -271,7 +309,11 @@ class InMemoryVectorStore(VectorStore):
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> "InMemoryVectorStore": ) -> "InMemoryVectorStore":
return cls.from_texts(texts, embedding, metadatas, **kwargs) store = cls(
embedding=embedding,
)
await store.aadd_texts(texts=texts, metadatas=metadatas, **kwargs)
return store
@classmethod @classmethod
def load( def load(

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, Mock
import pytest import pytest
from langchain_standard_tests.integration_tests.vectorstores import ( from langchain_standard_tests.integration_tests.vectorstores import (
@ -24,25 +25,39 @@ class TestAsyncInMemoryReadWriteTestSuite(AsyncReadWriteTestSuite):
return InMemoryVectorStore(embedding=self.get_embeddings()) return InMemoryVectorStore(embedding=self.get_embeddings())
async def test_inmemory() -> None: async def test_inmemory_similarity_search() -> None:
"""Test end to end construction and search.""" """Test end to end similarity search."""
store = await InMemoryVectorStore.afrom_texts( store = await InMemoryVectorStore.afrom_texts(
["foo", "bar", "baz"], DeterministicFakeEmbedding(size=6) ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
) )
output = await store.asimilarity_search("foo", k=1)
# Check sync version
output = store.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", id=AnyStr())] assert output == [Document(page_content="foo", id=AnyStr())]
# Check async version
output = await store.asimilarity_search("bar", k=2) output = await store.asimilarity_search("bar", k=2)
assert output == [ assert output == [
Document(page_content="bar", id=AnyStr()), Document(page_content="bar", id=AnyStr()),
Document(page_content="baz", id=AnyStr()), Document(page_content="baz", id=AnyStr()),
] ]
output2 = await store.asimilarity_search_with_score("bar", k=2)
assert output2[0][1] > output2[1][1] async def test_inmemory_similarity_search_with_score() -> None:
"""Test end to end similarity search with score"""
store = await InMemoryVectorStore.afrom_texts(
["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
)
output = store.similarity_search_with_score("foo", k=1)
assert output[0][0].page_content == "foo"
output = await store.asimilarity_search_with_score("bar", k=2)
assert output[0][1] > output[1][1]
async def test_add_by_ids() -> None: async def test_add_by_ids() -> None:
"""Test add texts with ids."""
vectorstore = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=6)) vectorstore = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=6))
# Check sync version # Check sync version
@ -50,17 +65,25 @@ async def test_add_by_ids() -> None:
assert ids1 == ["1", "2", "3"] assert ids1 == ["1", "2", "3"]
assert sorted(vectorstore.store.keys()) == ["1", "2", "3"] assert sorted(vectorstore.store.keys()) == ["1", "2", "3"]
# Check async version
ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"]) ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"])
assert ids2 == ["4", "5", "6"] assert ids2 == ["4", "5", "6"]
assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"] assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"]
async def test_inmemory_mmr() -> None: async def test_inmemory_mmr() -> None:
"""Test MMR search"""
texts = ["foo", "foo", "fou", "foy"] texts = ["foo", "foo", "fou", "foy"]
docsearch = await InMemoryVectorStore.afrom_texts( docsearch = await InMemoryVectorStore.afrom_texts(
texts, DeterministicFakeEmbedding(size=6) texts, DeterministicFakeEmbedding(size=6)
) )
# make sure we can k > docstore size # make sure we can k > docstore size
output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0] == Document(page_content="foo", id=AnyStr())
assert output[1] == Document(page_content="foy", id=AnyStr())
# Check async version
output = await docsearch.amax_marginal_relevance_search( output = await docsearch.amax_marginal_relevance_search(
"foo", k=10, lambda_mult=0.1 "foo", k=10, lambda_mult=0.1
) )
@ -85,13 +108,91 @@ async def test_inmemory_dump_load(tmp_path: Path) -> None:
async def test_inmemory_filter() -> None: async def test_inmemory_filter() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search with filter."""
store = await InMemoryVectorStore.afrom_texts( store = await InMemoryVectorStore.afrom_texts(
["foo", "bar"], ["foo", "bar"],
DeterministicFakeEmbedding(size=6), DeterministicFakeEmbedding(size=6),
[{"id": 1}, {"id": 2}], [{"id": 1}, {"id": 2}],
) )
output = await store.asimilarity_search(
"baz", filter=lambda doc: doc.metadata["id"] == 1 # Check sync version
) output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())] assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())]
# filter with not stored document id
output = await store.asimilarity_search(
"baz", filter=lambda doc: doc.metadata["id"] == 3
)
assert output == []
async def test_inmemory_upsert() -> None:
    """Upsert via the sync and async APIs, then overwrite an existing id."""
    fake_embedding = DeterministicFakeEmbedding(size=2)
    vectorstore = InMemoryVectorStore(embedding=fake_embedding)

    # Sync path: one document stored under id "1".
    vectorstore.upsert([Document(page_content="foo", id="1")])
    assert sorted(vectorstore.store) == ["1"]

    # Async path: a second document stored under id "2".
    await vectorstore.aupsert([Document(page_content="bar", id="2")])
    assert sorted(vectorstore.store) == ["1", "2"]

    # Upserting the same id again must replace the stored record.
    replacement = Document(
        page_content="baz", id="2", metadata={"metadata": "value"}
    )
    await vectorstore.aupsert([replacement])

    stored_record = vectorstore.store["2"]
    expected_vector = fake_embedding.embed_query("baz")
    assert stored_record == {
        "id": "2",
        "text": "baz",
        "vector": expected_vector,
        "metadata": {"metadata": "value"},
    }
async def test_inmemory_get_by_ids() -> None:
    """Look documents up by id through the sync and async getters."""
    vectorstore = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3))
    seed_documents = [
        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
        Document(page_content="bar", id="2"),
        Document(page_content="baz", id="3"),
    ]
    vectorstore.upsert(seed_documents)

    # Sync lookup of two stored ids.
    found = vectorstore.get_by_ids(["1", "2"])
    assert found == [
        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
        Document(page_content="bar", id="2"),
    ]

    # Async lookup; id "5" was never stored, and the expected result has
    # no entry for it.
    found = await vectorstore.aget_by_ids(["1", "3", "5"])
    assert found == [
        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
        Document(page_content="baz", id="3"),
    ]
async def test_inmemory_call_embeddings_async() -> None:
    """Check that the async store methods await the async embedding methods.

    The embedding object is a ``Mock`` wrapping a real
    ``DeterministicFakeEmbedding``, with ``aembed_documents`` and
    ``aembed_query`` replaced by ``AsyncMock`` instances so that their
    await counts can be inspected.
    """
    embeddings_mock = Mock(
        wraps=DeterministicFakeEmbedding(size=3),
        aembed_documents=AsyncMock(),
        aembed_query=AsyncMock(),
    )
    store = InMemoryVectorStore(embedding=embeddings_mock)

    # Pass a list of texts: a bare string would be iterated character by
    # character and treated as three separate one-letter texts.
    await store.aadd_texts(["foo"])
    await store.asimilarity_search("foo", k=1)

    # Ensure the async embedding function is called
    assert embeddings_mock.aembed_documents.await_count == 1
    assert embeddings_mock.aembed_query.await_count == 1