mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
core[minor]: Improve support for id in VectorStore (#26660)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
committed by
GitHub
parent
931ce8d026
commit
c4ebccfec2
@@ -10,8 +10,10 @@ import uuid
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
|
||||
@@ -25,10 +27,6 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
# One of the kwargs should be `ids` which is a list of ids
|
||||
# associated with the texts.
|
||||
# This is not yet enforced in the type signature for backwards compatibility
|
||||
# with existing implementations.
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
@@ -68,12 +66,59 @@ class CustomAddTextsVectorstore(VectorStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def test_default_add_documents() -> None:
|
||||
class CustomAddDocumentsVectorstore(VectorStore):
|
||||
"""A vectorstore that only implements add documents."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.store: dict[str, Document] = {}
|
||||
|
||||
def add_documents(
|
||||
self,
|
||||
documents: list[Document],
|
||||
*,
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
ids_ = []
|
||||
ids_iter = iter(ids or [])
|
||||
for document in documents:
|
||||
id_ = next(ids_iter) if ids else document.id or str(uuid.uuid4())
|
||||
self.store[id_] = Document(
|
||||
id=id_, page_content=document.page_content, metadata=document.metadata
|
||||
)
|
||||
ids_.append(id_)
|
||||
return ids_
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
|
||||
return [self.store[id] for id in ids if id in self.store]
|
||||
|
||||
@classmethod
|
||||
def from_texts( # type: ignore
|
||||
cls,
|
||||
texts: list[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> CustomAddDocumentsVectorstore:
|
||||
vectorstore = CustomAddDocumentsVectorstore()
|
||||
vectorstore.add_texts(texts, metadatas=metadatas, **kwargs)
|
||||
return vectorstore
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
|
||||
)
|
||||
def test_default_add_documents(vs_class: type[VectorStore]) -> None:
|
||||
"""Test that we can implement the upsert method of the CustomVectorStore
|
||||
class without violating the Liskov Substitution Principle.
|
||||
"""
|
||||
|
||||
store = CustomAddTextsVectorstore()
|
||||
store = vs_class()
|
||||
|
||||
# Check upsert with id
|
||||
assert store.add_documents([Document(id="1", page_content="hello")]) == ["1"]
|
||||
@@ -95,8 +140,11 @@ def test_default_add_documents() -> None:
|
||||
assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")]
|
||||
|
||||
|
||||
def test_default_add_texts() -> None:
|
||||
store = CustomAddTextsVectorstore()
|
||||
@pytest.mark.parametrize(
|
||||
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
|
||||
)
|
||||
def test_default_add_texts(vs_class: type[VectorStore]) -> None:
|
||||
store = vs_class()
|
||||
# Check that default implementation of add_texts works
|
||||
assert store.add_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"]
|
||||
|
||||
@@ -122,9 +170,12 @@ def test_default_add_texts() -> None:
|
||||
]
|
||||
|
||||
|
||||
async def test_default_aadd_documents() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
|
||||
)
|
||||
async def test_default_aadd_documents(vs_class: type[VectorStore]) -> None:
|
||||
"""Test delegation to the synchronous method."""
|
||||
store = CustomAddTextsVectorstore()
|
||||
store = vs_class()
|
||||
|
||||
# Check upsert with id
|
||||
assert await store.aadd_documents([Document(id="1", page_content="hello")]) == ["1"]
|
||||
@@ -146,10 +197,13 @@ async def test_default_aadd_documents() -> None:
|
||||
assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")]
|
||||
|
||||
|
||||
async def test_default_aadd_texts() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
|
||||
)
|
||||
async def test_default_aadd_texts(vs_class: type[VectorStore]) -> None:
|
||||
"""Test delegation to the synchronous method."""
|
||||
store = CustomAddTextsVectorstore()
|
||||
# Check that default implementation of add_texts works
|
||||
store = vs_class()
|
||||
# Check that default implementation of aadd_texts works
|
||||
assert await store.aadd_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"]
|
||||
|
||||
assert await store.aget_by_ids(["3", "4"]) == [
|
||||
@@ -172,3 +226,61 @@ async def test_default_aadd_texts() -> None:
|
||||
Document(id=ids_2[0], page_content="foo", metadata={"foo": "bar"}),
|
||||
Document(id=ids_2[1], page_content="bar", metadata={"foo": "bar"}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
|
||||
)
|
||||
def test_default_from_documents(vs_class: type[VectorStore]) -> None:
|
||||
embeddings = FakeEmbeddings(size=1)
|
||||
store = vs_class.from_documents(
|
||||
[Document(id="1", page_content="hello", metadata={"foo": "bar"})], embeddings
|
||||
)
|
||||
|
||||
assert store.get_by_ids(["1"]) == [
|
||||
Document(id="1", page_content="hello", metadata={"foo": "bar"})
|
||||
]
|
||||
|
||||
# from_documents with ids in args
|
||||
store = vs_class.from_documents(
|
||||
[Document(page_content="hello", metadata={"foo": "bar"})], embeddings, ids=["1"]
|
||||
)
|
||||
|
||||
assert store.get_by_ids(["1"]) == [
|
||||
Document(id="1", page_content="hello", metadata={"foo": "bar"})
|
||||
]
|
||||
|
||||
# Test from_documents with id specified in both document and ids
|
||||
original_document = Document(id="7", page_content="baz")
|
||||
store = vs_class.from_documents([original_document], embeddings, ids=["6"])
|
||||
assert original_document.id == "7" # original document should not be modified
|
||||
assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
|
||||
)
|
||||
async def test_default_afrom_documents(vs_class: type[VectorStore]) -> None:
|
||||
embeddings = FakeEmbeddings(size=1)
|
||||
store = await vs_class.afrom_documents(
|
||||
[Document(id="1", page_content="hello", metadata={"foo": "bar"})], embeddings
|
||||
)
|
||||
|
||||
assert await store.aget_by_ids(["1"]) == [
|
||||
Document(id="1", page_content="hello", metadata={"foo": "bar"})
|
||||
]
|
||||
|
||||
# from_documents with ids in args
|
||||
store = await vs_class.afrom_documents(
|
||||
[Document(page_content="hello", metadata={"foo": "bar"})], embeddings, ids=["1"]
|
||||
)
|
||||
|
||||
assert await store.aget_by_ids(["1"]) == [
|
||||
Document(id="1", page_content="hello", metadata={"foo": "bar"})
|
||||
]
|
||||
|
||||
# Test afrom_documents with id specified in both document and ids
|
||||
original_document = Document(id="7", page_content="baz")
|
||||
store = await vs_class.afrom_documents([original_document], embeddings, ids=["6"])
|
||||
assert original_document.id == "7" # original document should not be modified
|
||||
assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")]
|
||||
|
Reference in New Issue
Block a user