Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-23 07:09:31 +00:00
core[minor]: Support asynchronous in InMemoryVectorStore (#24472)
### Description

* Support asynchronous operations in InMemoryVectorStore.
* Since embeddings may be callable asynchronously, ensure that both the asynchronous and synchronous functions operate correctly.
This commit is contained in:
parent
5fdbdd6bec
commit
256bad3251
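Taken together, the changes below let InMemoryVectorStore be driven end to end through its async API. A minimal usage sketch (not part of the diff; import paths are assumed from the test file further down):

import asyncio

# Assumed import paths; adjust to your langchain-core version if needed.
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore


async def main() -> None:
    # afrom_texts now embeds via aembed_documents instead of falling back
    # to the synchronous from_texts path.
    store = await InMemoryVectorStore.afrom_texts(
        ["foo", "bar", "baz"],
        DeterministicFakeEmbedding(size=3),
    )

    # asimilarity_search now awaits aembed_query under the hood.
    docs = await store.asimilarity_search("foo", k=1)
    print([doc.page_content for doc in docs])


asyncio.run(main())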
@@ -8,7 +8,6 @@ from typing import (
     Any,
     Callable,
     Dict,
-    Iterable,
     List,
     Optional,
     Sequence,
@@ -74,6 +73,27 @@ class InMemoryVectorStore(VectorStore):
             "failed": [],
         }

+    async def aupsert(
+        self, items: Sequence[Document], /, **kwargs: Any
+    ) -> UpsertResponse:
+        vectors = await self.embedding.aembed_documents(
+            [item.page_content for item in items]
+        )
+        ids = []
+        for item, vector in zip(items, vectors):
+            doc_id = item.id if item.id else str(uuid.uuid4())
+            ids.append(doc_id)
+            self.store[doc_id] = {
+                "id": doc_id,
+                "vector": vector,
+                "text": item.page_content,
+                "metadata": item.metadata,
+            }
+        return {
+            "succeeded": ids,
+            "failed": [],
+        }
+
     def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
         """Get documents by their ids.

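The new aupsert mirrors the synchronous upsert but embeds the documents with aembed_documents and assigns a UUID to any document without an id. A short sketch of calling it (not from the diff; imports assumed as above):

import asyncio

from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding  # assumed import path
from langchain_core.vectorstores import InMemoryVectorStore  # assumed import path


async def main() -> None:
    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3))
    # Documents without an explicit id get a generated UUID.
    response = await store.aupsert(
        [
            Document(page_content="foo", id="1"),
            Document(page_content="bar"),
        ]
    )
    print(response["succeeded"])  # ['1', '<generated uuid>']


asyncio.run(main())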
@@ -108,14 +128,6 @@ class InMemoryVectorStore(VectorStore):
         """
         return self.get_by_ids(ids)

-    async def aadd_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        return self.add_texts(texts, metadatas, **kwargs)
-
     def _similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
@@ -172,7 +184,13 @@ class InMemoryVectorStore(VectorStore):
     async def asimilarity_search_with_score(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> List[Tuple[Document, float]]:
-        return self.similarity_search_with_score(query, k, **kwargs)
+        embedding = await self.embedding.aembed_query(query)
+        docs = self.similarity_search_with_score_by_vector(
+            embedding,
+            k,
+            **kwargs,
+        )
+        return docs

     def similarity_search_by_vector(
         self,
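With this change asimilarity_search_with_score awaits aembed_query, so only the in-memory scoring runs synchronously. A usage sketch under the same assumed imports:

import asyncio

from langchain_core.embeddings import DeterministicFakeEmbedding  # assumed import path
from langchain_core.vectorstores import InMemoryVectorStore  # assumed import path


async def main() -> None:
    store = await InMemoryVectorStore.afrom_texts(
        ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
    )
    # Each result is a (Document, score) pair, best match first.
    for doc, score in await store.asimilarity_search_with_score("foo", k=2):
        print(doc.page_content, score)


asyncio.run(main())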
@@ -200,7 +218,10 @@ class InMemoryVectorStore(VectorStore):
     async def asimilarity_search(
         self, query: str, k: int = 4, **kwargs: Any
     ) -> List[Document]:
-        return self.similarity_search(query, k, **kwargs)
+        return [
+            doc
+            for doc, _ in await self.asimilarity_search_with_score(query, k, **kwargs)
+        ]

     def max_marginal_relevance_search_by_vector(
         self,
@@ -249,6 +270,23 @@ class InMemoryVectorStore(VectorStore):
             **kwargs,
         )

+    async def amax_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        embedding_vector = await self.embedding.aembed_query(query)
+        return self.max_marginal_relevance_search_by_vector(
+            embedding_vector,
+            k,
+            fetch_k,
+            lambda_mult=lambda_mult,
+            **kwargs,
+        )
+
     @classmethod
     def from_texts(
         cls,
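The new amax_marginal_relevance_search awaits aembed_query for the query vector and then reuses the existing synchronous MMR computation over the stored vectors. A sketch of how it might be called (assumed imports as above):

import asyncio

from langchain_core.embeddings import DeterministicFakeEmbedding  # assumed import path
from langchain_core.vectorstores import InMemoryVectorStore  # assumed import path


async def main() -> None:
    store = await InMemoryVectorStore.afrom_texts(
        ["foo", "foo", "fou", "foy"], DeterministicFakeEmbedding(size=6)
    )
    # lambda_mult near 0 favours diversity, near 1 favours pure similarity.
    docs = await store.amax_marginal_relevance_search(
        "foo", k=2, fetch_k=4, lambda_mult=0.5
    )
    print([doc.page_content for doc in docs])


asyncio.run(main())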
@@ -271,7 +309,11 @@ class InMemoryVectorStore(VectorStore):
         metadatas: Optional[List[dict]] = None,
         **kwargs: Any,
     ) -> "InMemoryVectorStore":
-        return cls.from_texts(texts, embedding, metadatas, **kwargs)
+        store = cls(
+            embedding=embedding,
+        )
+        await store.aadd_texts(texts=texts, metadatas=metadatas, **kwargs)
+        return store

     @classmethod
     def load(
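afrom_texts now constructs the store and then awaits aadd_texts rather than delegating to the synchronous from_texts. A sketch combining it with metadata and a filtered async search (not from the diff; imports assumed as above):

import asyncio

from langchain_core.embeddings import DeterministicFakeEmbedding  # assumed import path
from langchain_core.vectorstores import InMemoryVectorStore  # assumed import path


async def main() -> None:
    store = await InMemoryVectorStore.afrom_texts(
        ["foo", "bar"],
        DeterministicFakeEmbedding(size=6),
        metadatas=[{"id": 1}, {"id": 2}],
    )
    # The filter callable receives each stored Document and returns a bool.
    docs = await store.asimilarity_search(
        "foo", k=1, filter=lambda doc: doc.metadata["id"] == 1
    )
    print([doc.metadata for doc in docs])


asyncio.run(main())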
@@ -1,4 +1,5 @@
 from pathlib import Path
+from unittest.mock import AsyncMock, Mock

 import pytest
 from langchain_standard_tests.integration_tests.vectorstores import (
@@ -24,25 +25,39 @@ class TestAsyncInMemoryReadWriteTestSuite(AsyncReadWriteTestSuite):
         return InMemoryVectorStore(embedding=self.get_embeddings())


-async def test_inmemory() -> None:
-    """Test end to end construction and search."""
+async def test_inmemory_similarity_search() -> None:
+    """Test end to end similarity search."""
     store = await InMemoryVectorStore.afrom_texts(
-        ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=6)
+        ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
     )
-    output = await store.asimilarity_search("foo", k=1)
+
+    # Check sync version
+    output = store.similarity_search("foo", k=1)
     assert output == [Document(page_content="foo", id=AnyStr())]

+    # Check async version
     output = await store.asimilarity_search("bar", k=2)
     assert output == [
         Document(page_content="bar", id=AnyStr()),
         Document(page_content="baz", id=AnyStr()),
     ]

-    output2 = await store.asimilarity_search_with_score("bar", k=2)
-    assert output2[0][1] > output2[1][1]
+
+async def test_inmemory_similarity_search_with_score() -> None:
+    """Test end to end similarity search with score"""
+    store = await InMemoryVectorStore.afrom_texts(
+        ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
+    )
+
+    output = store.similarity_search_with_score("foo", k=1)
+    assert output[0][0].page_content == "foo"
+
+    output = await store.asimilarity_search_with_score("bar", k=2)
+    assert output[0][1] > output[1][1]


 async def test_add_by_ids() -> None:
+    """Test add texts with ids."""
     vectorstore = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=6))

     # Check sync version
@@ -50,17 +65,25 @@ async def test_add_by_ids() -> None:
     assert ids1 == ["1", "2", "3"]
     assert sorted(vectorstore.store.keys()) == ["1", "2", "3"]

+    # Check async version
     ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"])
     assert ids2 == ["4", "5", "6"]
     assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"]


 async def test_inmemory_mmr() -> None:
+    """Test MMR search"""
     texts = ["foo", "foo", "fou", "foy"]
     docsearch = await InMemoryVectorStore.afrom_texts(
         texts, DeterministicFakeEmbedding(size=6)
     )
     # make sure we can k > docstore size
+    output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
+    assert len(output) == len(texts)
+    assert output[0] == Document(page_content="foo", id=AnyStr())
+    assert output[1] == Document(page_content="foy", id=AnyStr())
+
+    # Check async version
     output = await docsearch.amax_marginal_relevance_search(
         "foo", k=10, lambda_mult=0.1
     )
@@ -85,13 +108,91 @@ async def test_inmemory_dump_load(tmp_path: Path) -> None:


 async def test_inmemory_filter() -> None:
-    """Test end to end construction and search."""
+    """Test end to end construction and search with filter."""
     store = await InMemoryVectorStore.afrom_texts(
         ["foo", "bar"],
         DeterministicFakeEmbedding(size=6),
         [{"id": 1}, {"id": 2}],
     )
-    output = await store.asimilarity_search(
-        "baz", filter=lambda doc: doc.metadata["id"] == 1
-    )
+
+    # Check sync version
+    output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
     assert output == [Document(page_content="foo", metadata={"id": 1}, id=AnyStr())]
+
+    # filter with not stored document id
+    output = await store.asimilarity_search(
+        "baz", filter=lambda doc: doc.metadata["id"] == 3
+    )
+    assert output == []
+
+
+async def test_inmemory_upsert() -> None:
+    """Test upsert documents."""
+    embedding = DeterministicFakeEmbedding(size=2)
+    store = InMemoryVectorStore(embedding=embedding)
+
+    # Check sync version
+    store.upsert([Document(page_content="foo", id="1")])
+    assert sorted(store.store.keys()) == ["1"]
+
+    # Check async version
+    await store.aupsert([Document(page_content="bar", id="2")])
+    assert sorted(store.store.keys()) == ["1", "2"]
+
+    # update existing document
+    await store.aupsert(
+        [Document(page_content="baz", id="2", metadata={"metadata": "value"})]
+    )
+    item = store.store["2"]
+
+    baz_vector = embedding.embed_query("baz")
+    assert item == {
+        "id": "2",
+        "text": "baz",
+        "vector": baz_vector,
+        "metadata": {"metadata": "value"},
+    }
+
+
+async def test_inmemory_get_by_ids() -> None:
+    """Test get by ids."""
+
+    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3))
+
+    store.upsert(
+        [
+            Document(page_content="foo", id="1", metadata={"metadata": "value"}),
+            Document(page_content="bar", id="2"),
+            Document(page_content="baz", id="3"),
+        ],
+    )
+
+    # Check sync version
+    output = store.get_by_ids(["1", "2"])
+    assert output == [
+        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
+        Document(page_content="bar", id="2"),
+    ]
+
+    # Check async version
+    output = await store.aget_by_ids(["1", "3", "5"])
+    assert output == [
+        Document(page_content="foo", id="1", metadata={"metadata": "value"}),
+        Document(page_content="baz", id="3"),
+    ]
+
+
+async def test_inmemory_call_embeddings_async() -> None:
+    embeddings_mock = Mock(
+        wraps=DeterministicFakeEmbedding(size=3),
+        aembed_documents=AsyncMock(),
+        aembed_query=AsyncMock(),
+    )
+    store = InMemoryVectorStore(embedding=embeddings_mock)
+
+    await store.aadd_texts("foo")
+    await store.asimilarity_search("foo", k=1)
+
+    # Ensure the async embedding function is called
+    assert embeddings_mock.aembed_documents.await_count == 1
+    assert embeddings_mock.aembed_query.await_count == 1