core[minor]: Improve support for id in VectorStore (#26660)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet
2024-10-07 21:01:08 +02:00
committed by GitHub
parent 931ce8d026
commit c4ebccfec2
3 changed files with 178 additions and 27 deletions

View File

@@ -10,8 +10,10 @@ import uuid
from collections.abc import Iterable, Sequence
from typing import Any, Optional
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.embeddings import Embeddings, FakeEmbeddings
from langchain_core.vectorstores import VectorStore
@@ -25,10 +27,6 @@ class CustomAddTextsVectorstore(VectorStore):
self,
texts: Iterable[str],
metadatas: Optional[list[dict]] = None,
# One of the kwargs should be `ids` which is a list of ids
# associated with the texts.
# This is not yet enforced in the type signature for backwards compatibility
# with existing implementations.
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> list[str]:
@@ -68,12 +66,59 @@ class CustomAddTextsVectorstore(VectorStore):
raise NotImplementedError()
def test_default_add_documents() -> None:
class CustomAddDocumentsVectorstore(VectorStore):
"""A vectorstore that only implements add documents."""
def __init__(self) -> None:
self.store: dict[str, Document] = {}
def add_documents(
self,
documents: list[Document],
*,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> list[str]:
ids_ = []
ids_iter = iter(ids or [])
for document in documents:
id_ = next(ids_iter) if ids else document.id or str(uuid.uuid4())
self.store[id_] = Document(
id=id_, page_content=document.page_content, metadata=document.metadata
)
ids_.append(id_)
return ids_
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store]
@classmethod
def from_texts( # type: ignore
cls,
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> CustomAddDocumentsVectorstore:
vectorstore = CustomAddDocumentsVectorstore()
vectorstore.add_texts(texts, metadatas=metadatas, **kwargs)
return vectorstore
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
raise NotImplementedError()
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
def test_default_add_documents(vs_class: type[VectorStore]) -> None:
"""Test that we can implement the upsert method of the CustomVectorStore
class without violating the Liskov Substitution Principle.
"""
store = CustomAddTextsVectorstore()
store = vs_class()
# Check upsert with id
assert store.add_documents([Document(id="1", page_content="hello")]) == ["1"]
@@ -95,8 +140,11 @@ def test_default_add_documents() -> None:
assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")]
def test_default_add_texts() -> None:
store = CustomAddTextsVectorstore()
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
def test_default_add_texts(vs_class: type[VectorStore]) -> None:
store = vs_class()
# Check that default implementation of add_texts works
assert store.add_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"]
@@ -122,9 +170,12 @@ def test_default_add_texts() -> None:
]
async def test_default_aadd_documents() -> None:
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
async def test_default_aadd_documents(vs_class: type[VectorStore]) -> None:
"""Test delegation to the synchronous method."""
store = CustomAddTextsVectorstore()
store = vs_class()
# Check upsert with id
assert await store.aadd_documents([Document(id="1", page_content="hello")]) == ["1"]
@@ -146,10 +197,13 @@ async def test_default_aadd_documents() -> None:
assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")]
async def test_default_aadd_texts() -> None:
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
async def test_default_aadd_texts(vs_class: type[VectorStore]) -> None:
"""Test delegation to the synchronous method."""
store = CustomAddTextsVectorstore()
# Check that default implementation of add_texts works
store = vs_class()
# Check that default implementation of aadd_texts works
assert await store.aadd_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"]
assert await store.aget_by_ids(["3", "4"]) == [
@@ -172,3 +226,61 @@ async def test_default_aadd_texts() -> None:
Document(id=ids_2[0], page_content="foo", metadata={"foo": "bar"}),
Document(id=ids_2[1], page_content="bar", metadata={"foo": "bar"}),
]
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
def test_default_from_documents(vs_class: type[VectorStore]) -> None:
embeddings = FakeEmbeddings(size=1)
store = vs_class.from_documents(
[Document(id="1", page_content="hello", metadata={"foo": "bar"})], embeddings
)
assert store.get_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# from_documents with ids in args
store = vs_class.from_documents(
[Document(page_content="hello", metadata={"foo": "bar"})], embeddings, ids=["1"]
)
assert store.get_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# Test from_documents with id specified in both document and ids
original_document = Document(id="7", page_content="baz")
store = vs_class.from_documents([original_document], embeddings, ids=["6"])
assert original_document.id == "7" # original document should not be modified
assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")]
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
async def test_default_afrom_documents(vs_class: type[VectorStore]) -> None:
embeddings = FakeEmbeddings(size=1)
store = await vs_class.afrom_documents(
[Document(id="1", page_content="hello", metadata={"foo": "bar"})], embeddings
)
assert await store.aget_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# from_documents with ids in args
store = await vs_class.afrom_documents(
[Document(page_content="hello", metadata={"foo": "bar"})], embeddings, ids=["1"]
)
assert await store.aget_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# Test afrom_documents with id specified in both document and ids
original_document = Document(id="7", page_content="baz")
store = await vs_class.afrom_documents([original_document], embeddings, ids=["6"])
assert original_document.id == "7" # original document should not be modified
assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")]