core[minor]: Improve support for id in VectorStore (#26660)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2024-10-07 21:01:08 +02:00 committed by GitHub
parent 931ce8d026
commit c4ebccfec2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 178 additions and 27 deletions

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from langchain_core.documents import Document from langchain_core.documents import Document
@ -42,7 +42,7 @@ class USearch(VectorStore):
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[Dict]] = None, metadatas: Optional[List[Dict]] = None,
ids: Optional[np.ndarray] = None, ids: Optional[Union[np.ndarray, list[str]]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore. """Run more texts through the embeddings and add to the vectorstore.
@ -69,6 +69,8 @@ class USearch(VectorStore):
last_id = int(self.ids[-1]) + 1 last_id = int(self.ids[-1]) + 1
if ids is None: if ids is None:
ids = np.array([str(last_id + id) for id, _ in enumerate(texts)]) ids = np.array([str(last_id + id) for id, _ in enumerate(texts)])
elif isinstance(ids, list):
ids = np.array(ids)
self.index.add(np.array(ids), np.array(embeddings)) self.index.add(np.array(ids), np.array(embeddings))
self.docstore.add(dict(zip(ids, documents))) self.docstore.add(dict(zip(ids, documents)))
@ -134,7 +136,7 @@ class USearch(VectorStore):
texts: List[str], texts: List[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[List[Dict]] = None, metadatas: Optional[List[Dict]] = None,
ids: Optional[np.ndarray] = None, ids: Optional[Union[np.ndarray, list[str]]] = None,
metric: str = "cos", metric: str = "cos",
**kwargs: Any, **kwargs: Any,
) -> USearch: ) -> USearch:
@ -159,6 +161,8 @@ class USearch(VectorStore):
documents: List[Document] = [] documents: List[Document] = []
if ids is None: if ids is None:
ids = np.array([str(id) for id, _ in enumerate(texts)]) ids = np.array([str(id) for id, _ in enumerate(texts)])
elif isinstance(ids, list):
ids = np.array(ids)
for i, text in enumerate(texts): for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {} metadata = metadatas[i] if metadatas else {}
documents.append(Document(page_content=text, metadata=metadata)) documents.append(Document(page_content=text, metadata=metadata))

View File

@ -25,7 +25,7 @@ import logging
import math import math
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Collection, Iterable, Sequence from collections.abc import Collection, Iterable, Iterator, Sequence
from itertools import cycle from itertools import cycle
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -61,10 +61,8 @@ class VectorStore(ABC):
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
# One of the kwargs should be `ids` which is a list of ids *,
# associated with the texts. ids: Optional[list[str]] = None,
# This is not yet enforced in the type signature for backwards compatibility
# with existing implementations.
**kwargs: Any, **kwargs: Any,
) -> list[str]: ) -> list[str]:
"""Run more texts through the embeddings and add to the vectorstore. """Run more texts through the embeddings and add to the vectorstore.
@ -72,6 +70,7 @@ class VectorStore(ABC):
Args: Args:
texts: Iterable of strings to add to the vectorstore. texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts. metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of IDs associated with the texts.
**kwargs: vectorstore specific parameters. **kwargs: vectorstore specific parameters.
One of the kwargs should be `ids` which is a list of ids One of the kwargs should be `ids` which is a list of ids
associated with the texts. associated with the texts.
@ -99,10 +98,14 @@ class VectorStore(ABC):
f"Got {len(metadatas)} metadatas and {len(texts_)} texts." f"Got {len(metadatas)} metadatas and {len(texts_)} texts."
) )
metadatas_ = iter(metadatas) if metadatas else cycle([{}]) metadatas_ = iter(metadatas) if metadatas else cycle([{}])
ids_: Iterator[Optional[str]] = iter(ids) if ids else cycle([None])
docs = [ docs = [
Document(page_content=text, metadata=metadata_) Document(id=id_, page_content=text, metadata=metadata_)
for text, metadata_ in zip(texts, metadatas_) for text, metadata_, id_ in zip(texts, metadatas_, ids_)
] ]
if ids is not None:
# For backward compatibility
kwargs["ids"] = ids
return self.add_documents(docs, **kwargs) return self.add_documents(docs, **kwargs)
raise NotImplementedError( raise NotImplementedError(
@ -206,6 +209,8 @@ class VectorStore(ABC):
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
*,
ids: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> list[str]: ) -> list[str]:
"""Async run more texts through the embeddings and add to the vectorstore. """Async run more texts through the embeddings and add to the vectorstore.
@ -214,6 +219,7 @@ class VectorStore(ABC):
texts: Iterable of strings to add to the vectorstore. texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts. metadatas: Optional list of metadatas associated with the texts.
Default is None. Default is None.
ids: Optional list of IDs associated with the texts.
**kwargs: vectorstore specific parameters. **kwargs: vectorstore specific parameters.
Returns: Returns:
@ -223,6 +229,9 @@ class VectorStore(ABC):
ValueError: If the number of metadatas does not match the number of texts. ValueError: If the number of metadatas does not match the number of texts.
ValueError: If the number of ids does not match the number of texts. ValueError: If the number of ids does not match the number of texts.
""" """
if ids is not None:
# For backward compatibility
kwargs["ids"] = ids
if type(self).aadd_documents != VectorStore.aadd_documents: if type(self).aadd_documents != VectorStore.aadd_documents:
# Import document in local scope to avoid circular imports # Import document in local scope to avoid circular imports
from langchain_core.documents import Document from langchain_core.documents import Document
@ -239,12 +248,12 @@ class VectorStore(ABC):
f"Got {len(metadatas)} metadatas and {len(texts_)} texts." f"Got {len(metadatas)} metadatas and {len(texts_)} texts."
) )
metadatas_ = iter(metadatas) if metadatas else cycle([{}]) metadatas_ = iter(metadatas) if metadatas else cycle([{}])
ids_: Iterator[Optional[str]] = iter(ids) if ids else cycle([None])
docs = [ docs = [
Document(page_content=text, metadata=metadata_) Document(id=id_, page_content=text, metadata=metadata_)
for text, metadata_ in zip(texts, metadatas_) for text, metadata_, id_ in zip(texts, metadatas_, ids_)
] ]
return await self.aadd_documents(docs, **kwargs) return await self.aadd_documents(docs, **kwargs)
return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs) return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs)
@ -827,6 +836,15 @@ class VectorStore(ABC):
""" """
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
if "ids" not in kwargs:
ids = [doc.id for doc in documents]
# If there's at least one valid ID, we'll assume that IDs
# should be used.
if any(ids):
kwargs["ids"] = ids
return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs) return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)
@classmethod @classmethod
@ -848,6 +866,15 @@ class VectorStore(ABC):
""" """
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
if "ids" not in kwargs:
ids = [doc.id for doc in documents]
# If there's at least one valid ID, we'll assume that IDs
# should be used.
if any(ids):
kwargs["ids"] = ids
return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs) return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs)
@classmethod @classmethod
@ -857,6 +884,8 @@ class VectorStore(ABC):
texts: list[str], texts: list[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
*,
ids: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> VST:
"""Return VectorStore initialized from texts and embeddings. """Return VectorStore initialized from texts and embeddings.
@ -866,6 +895,7 @@ class VectorStore(ABC):
embedding: Embedding function to use. embedding: Embedding function to use.
metadatas: Optional list of metadatas associated with the texts. metadatas: Optional list of metadatas associated with the texts.
Default is None. Default is None.
ids: Optional list of IDs associated with the texts.
kwargs: Additional keyword arguments. kwargs: Additional keyword arguments.
Returns: Returns:
@ -878,6 +908,8 @@ class VectorStore(ABC):
texts: list[str], texts: list[str],
embedding: Embeddings, embedding: Embeddings,
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
*,
ids: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> VST: ) -> VST:
"""Async return VectorStore initialized from texts and embeddings. """Async return VectorStore initialized from texts and embeddings.
@ -887,11 +919,14 @@ class VectorStore(ABC):
embedding: Embedding function to use. embedding: Embedding function to use.
metadatas: Optional list of metadatas associated with the texts. metadatas: Optional list of metadatas associated with the texts.
Default is None. Default is None.
ids: Optional list of IDs associated with the texts.
kwargs: Additional keyword arguments. kwargs: Additional keyword arguments.
Returns: Returns:
VectorStore: VectorStore initialized from texts and embeddings. VectorStore: VectorStore initialized from texts and embeddings.
""" """
if ids is not None:
kwargs["ids"] = ids
return await run_in_executor( return await run_in_executor(
None, cls.from_texts, texts, embedding, metadatas, **kwargs None, cls.from_texts, texts, embedding, metadatas, **kwargs
) )

View File

@ -10,8 +10,10 @@ import uuid
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import Any, Optional from typing import Any, Optional
import pytest
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings, FakeEmbeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
@ -25,10 +27,6 @@ class CustomAddTextsVectorstore(VectorStore):
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[list[dict]] = None, metadatas: Optional[list[dict]] = None,
# One of the kwargs should be `ids` which is a list of ids
# associated with the texts.
# This is not yet enforced in the type signature for backwards compatibility
# with existing implementations.
ids: Optional[list[str]] = None, ids: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> list[str]: ) -> list[str]:
@ -68,12 +66,59 @@ class CustomAddTextsVectorstore(VectorStore):
raise NotImplementedError() raise NotImplementedError()
def test_default_add_documents() -> None: class CustomAddDocumentsVectorstore(VectorStore):
"""A vectorstore that only implements add documents."""
def __init__(self) -> None:
self.store: dict[str, Document] = {}
def add_documents(
self,
documents: list[Document],
*,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> list[str]:
ids_ = []
ids_iter = iter(ids or [])
for document in documents:
id_ = next(ids_iter) if ids else document.id or str(uuid.uuid4())
self.store[id_] = Document(
id=id_, page_content=document.page_content, metadata=document.metadata
)
ids_.append(id_)
return ids_
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
return [self.store[id] for id in ids if id in self.store]
@classmethod
def from_texts( # type: ignore
cls,
texts: list[str],
embedding: Embeddings,
metadatas: Optional[list[dict]] = None,
**kwargs: Any,
) -> CustomAddDocumentsVectorstore:
vectorstore = CustomAddDocumentsVectorstore()
vectorstore.add_texts(texts, metadatas=metadatas, **kwargs)
return vectorstore
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> list[Document]:
raise NotImplementedError()
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
def test_default_add_documents(vs_class: type[VectorStore]) -> None:
"""Test that we can implement the upsert method of the CustomVectorStore """Test that we can implement the upsert method of the CustomVectorStore
class without violating the Liskov Substitution Principle. class without violating the Liskov Substitution Principle.
""" """
store = CustomAddTextsVectorstore() store = vs_class()
# Check upsert with id # Check upsert with id
assert store.add_documents([Document(id="1", page_content="hello")]) == ["1"] assert store.add_documents([Document(id="1", page_content="hello")]) == ["1"]
@ -95,8 +140,11 @@ def test_default_add_documents() -> None:
assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")] assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")]
def test_default_add_texts() -> None: @pytest.mark.parametrize(
store = CustomAddTextsVectorstore() "vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
def test_default_add_texts(vs_class: type[VectorStore]) -> None:
store = vs_class()
# Check that default implementation of add_texts works # Check that default implementation of add_texts works
assert store.add_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"] assert store.add_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"]
@ -122,9 +170,12 @@ def test_default_add_texts() -> None:
] ]
async def test_default_aadd_documents() -> None: @pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
async def test_default_aadd_documents(vs_class: type[VectorStore]) -> None:
"""Test delegation to the synchronous method.""" """Test delegation to the synchronous method."""
store = CustomAddTextsVectorstore() store = vs_class()
# Check upsert with id # Check upsert with id
assert await store.aadd_documents([Document(id="1", page_content="hello")]) == ["1"] assert await store.aadd_documents([Document(id="1", page_content="hello")]) == ["1"]
@ -146,10 +197,13 @@ async def test_default_aadd_documents() -> None:
assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")] assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")]
async def test_default_aadd_texts() -> None: @pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
async def test_default_aadd_texts(vs_class: type[VectorStore]) -> None:
"""Test delegation to the synchronous method.""" """Test delegation to the synchronous method."""
store = CustomAddTextsVectorstore() store = vs_class()
# Check that default implementation of add_texts works # Check that default implementation of aadd_texts works
assert await store.aadd_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"] assert await store.aadd_texts(["hello", "world"], ids=["3", "4"]) == ["3", "4"]
assert await store.aget_by_ids(["3", "4"]) == [ assert await store.aget_by_ids(["3", "4"]) == [
@ -172,3 +226,61 @@ async def test_default_aadd_texts() -> None:
Document(id=ids_2[0], page_content="foo", metadata={"foo": "bar"}), Document(id=ids_2[0], page_content="foo", metadata={"foo": "bar"}),
Document(id=ids_2[1], page_content="bar", metadata={"foo": "bar"}), Document(id=ids_2[1], page_content="bar", metadata={"foo": "bar"}),
] ]
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
def test_default_from_documents(vs_class: type[VectorStore]) -> None:
embeddings = FakeEmbeddings(size=1)
store = vs_class.from_documents(
[Document(id="1", page_content="hello", metadata={"foo": "bar"})], embeddings
)
assert store.get_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# from_documents with ids in args
store = vs_class.from_documents(
[Document(page_content="hello", metadata={"foo": "bar"})], embeddings, ids=["1"]
)
assert store.get_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# Test from_documents with id specified in both document and ids
original_document = Document(id="7", page_content="baz")
store = vs_class.from_documents([original_document], embeddings, ids=["6"])
assert original_document.id == "7" # original document should not be modified
assert store.get_by_ids(["6"]) == [Document(id="6", page_content="baz")]
@pytest.mark.parametrize(
"vs_class", [CustomAddTextsVectorstore, CustomAddDocumentsVectorstore]
)
async def test_default_afrom_documents(vs_class: type[VectorStore]) -> None:
embeddings = FakeEmbeddings(size=1)
store = await vs_class.afrom_documents(
[Document(id="1", page_content="hello", metadata={"foo": "bar"})], embeddings
)
assert await store.aget_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# afrom_documents with ids in args
store = await vs_class.afrom_documents(
[Document(page_content="hello", metadata={"foo": "bar"})], embeddings, ids=["1"]
)
assert await store.aget_by_ids(["1"]) == [
Document(id="1", page_content="hello", metadata={"foo": "bar"})
]
# Test afrom_documents with id specified in both document and ids
original_document = Document(id="7", page_content="baz")
store = await vs_class.afrom_documents([original_document], embeddings, ids=["6"])
assert original_document.id == "7" # original document should not be modified
assert await store.aget_by_ids(["6"]) == [Document(id="6", page_content="baz")]