cli[patch]: implement minimal starter vector store (#28577)

Basically the same as core's in-memory vector store. Removed some
optional methods.
This commit is contained in:
ccurme 2024-12-06 13:10:22 -05:00 committed by GitHub
parent 5277a021c1
commit f3dc142d3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 208 additions and 83 deletions

View File

@ -45,5 +45,4 @@ _e2e_test:
poetry run pip install -e ../../../standard-tests && \ poetry run pip install -e ../../../standard-tests && \
make format lint tests && \ make format lint tests && \
poetry install --with test_integration && \ poetry install --with test_integration && \
rm tests/integration_tests/test_vectorstores.py && \
make integration_test make integration_test

View File

@ -2,23 +2,23 @@
from __future__ import annotations from __future__ import annotations
import uuid
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
Iterable, Iterator,
List, List,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
) )
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores.utils import _cosine_similarity as cosine_similarity
if TYPE_CHECKING:
from langchain_core.documents import Document
VST = TypeVar("VST", bound=VectorStore) VST = TypeVar("VST", bound=VectorStore)
@ -158,40 +158,184 @@ class __ModuleName__VectorStore(VectorStore):
""" # noqa: E501 """ # noqa: E501
_database: dict[str, tuple[Document, list[float]]] = {} def __init__(self, embedding: Embeddings) -> None:
"""Initialize with the given embedding function.
def add_texts( Args:
self, embedding: embedding function to use.
texts: Iterable[str], """
self._database: dict[str, dict[str, Any]] = {}
self.embedding = embedding
@classmethod
def from_texts(
cls: Type[__ModuleName__VectorStore],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> __ModuleName__VectorStore:
raise NotImplementedError store = cls(
embedding=embedding,
)
store.add_texts(texts=texts, metadatas=metadatas, **kwargs)
return store
# optional: add custom async implementations # optional: add custom async implementations
# async def aadd_texts( # @classmethod
# self, # async def afrom_texts(
# texts: Iterable[str], # cls: Type[VST],
# texts: List[str],
# embedding: Embeddings,
# metadatas: Optional[List[dict]] = None, # metadatas: Optional[List[dict]] = None,
# **kwargs: Any, # **kwargs: Any,
# ) -> List[str]: # ) -> VST:
# return await asyncio.get_running_loop().run_in_executor( # return await asyncio.get_running_loop().run_in_executor(
# None, partial(self.add_texts, **kwargs), texts, metadatas # None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
# ) # )
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: @property
raise NotImplementedError def embeddings(self) -> Embeddings:
return self.embedding
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Add documents to the store."""
texts = [doc.page_content for doc in documents]
vectors = self.embedding.embed_documents(texts)
if ids and len(ids) != len(texts):
msg = (
f"ids must be the same length as texts. "
f"Got {len(ids)} ids and {len(texts)} texts."
)
raise ValueError(msg)
id_iterator: Iterator[Optional[str]] = (
iter(ids) if ids else iter(doc.id for doc in documents)
)
ids_ = []
for doc, vector in zip(documents, vectors):
doc_id = next(id_iterator)
doc_id_ = doc_id if doc_id else str(uuid.uuid4())
ids_.append(doc_id_)
self._database[doc_id_] = {
"id": doc_id_,
"vector": vector,
"text": doc.page_content,
"metadata": doc.metadata,
}
return ids_
# optional: add custom async implementations
# async def aadd_documents(
# self,
# documents: List[Document],
# ids: Optional[List[str]] = None,
# **kwargs: Any,
# ) -> List[str]:
# raise NotImplementedError
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
if ids:
for _id in ids:
self._database.pop(_id, None)
# optional: add custom async implementations # optional: add custom async implementations
# async def adelete( # async def adelete(
# self, ids: Optional[List[str]] = None, **kwargs: Any # self, ids: Optional[List[str]] = None, **kwargs: Any
# ) -> Optional[bool]: # ) -> None:
# raise NotImplementedError # raise NotImplementedError
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
"""Get documents by their ids.
Args:
ids: The ids of the documents to get.
Returns:
A list of Document objects.
"""
documents = []
for doc_id in ids:
doc = self._database.get(doc_id)
if doc:
documents.append(
Document(
id=doc["id"],
page_content=doc["text"],
metadata=doc["metadata"],
)
)
return documents
# optional: add custom async implementations
# async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
# raise NotImplementedError
# NOTE: the below helper method implements similarity search for in-memory
# storage. It is optional and not a part of the vector store interface.
def _similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Callable[[Document], bool]] = None,
**kwargs: Any,
) -> List[tuple[Document, float, List[float]]]:
# get all docs with fixed order in list
docs = list(self._database.values())
if filter is not None:
docs = [
doc
for doc in docs
if filter(Document(page_content=doc["text"], metadata=doc["metadata"]))
]
if not docs:
return []
similarity = cosine_similarity([embedding], [doc["vector"] for doc in docs])[0]
# get the indices ordered by similarity score
top_k_idx = similarity.argsort()[::-1][:k]
return [
(
# Document
Document(
id=doc_dict["id"],
page_content=doc_dict["text"],
metadata=doc_dict["metadata"],
),
# Score
float(similarity[idx].item()),
# Embedding vector
doc_dict["vector"],
)
for idx in top_k_idx
# Assign using walrus operator to avoid multiple lookups
if (doc_dict := docs[idx])
]
def similarity_search( def similarity_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
raise NotImplementedError embedding = self.embedding.embed_query(query)
return [
doc
for doc, _, _ in self._similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs
)
]
# optional: add custom async implementations # optional: add custom async implementations
# async def asimilarity_search( # async def asimilarity_search(
@ -204,9 +348,15 @@ class __ModuleName__VectorStore(VectorStore):
# return await asyncio.get_event_loop().run_in_executor(None, func) # return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_with_score( def similarity_search_with_score(
self, *args: Any, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
raise NotImplementedError embedding = self.embedding.embed_query(query)
return [
(doc, similarity)
for doc, similarity, _ in self._similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs
)
]
# optional: add custom async implementations # optional: add custom async implementations
# async def asimilarity_search_with_score( # async def asimilarity_search_with_score(
@ -218,10 +368,12 @@ class __ModuleName__VectorStore(VectorStore):
# func = partial(self.similarity_search_with_score, *args, **kwargs) # func = partial(self.similarity_search_with_score, *args, **kwargs)
# return await asyncio.get_event_loop().run_in_executor(None, func) # return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector( ### ADDITIONAL OPTIONAL SEARCH METHODS BELOW ###
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]: # def similarity_search_by_vector(
raise NotImplementedError # self, embedding: List[float], k: int = 4, **kwargs: Any
# ) -> List[Document]:
# raise NotImplementedError
# optional: add custom async implementations # optional: add custom async implementations
# async def asimilarity_search_by_vector( # async def asimilarity_search_by_vector(
@ -233,15 +385,15 @@ class __ModuleName__VectorStore(VectorStore):
# func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) # func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
# return await asyncio.get_event_loop().run_in_executor(None, func) # return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search( # def max_marginal_relevance_search(
self, # self,
query: str, # query: str,
k: int = 4, # k: int = 4,
fetch_k: int = 20, # fetch_k: int = 20,
lambda_mult: float = 0.5, # lambda_mult: float = 0.5,
**kwargs: Any, # **kwargs: Any,
) -> List[Document]: # ) -> List[Document]:
raise NotImplementedError # raise NotImplementedError
# optional: add custom async implementations # optional: add custom async implementations
# async def amax_marginal_relevance_search( # async def amax_marginal_relevance_search(
@ -265,15 +417,15 @@ class __ModuleName__VectorStore(VectorStore):
# ) # )
# return await asyncio.get_event_loop().run_in_executor(None, func) # return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector( # def max_marginal_relevance_search_by_vector(
self, # self,
embedding: List[float], # embedding: List[float],
k: int = 4, # k: int = 4,
fetch_k: int = 20, # fetch_k: int = 20,
lambda_mult: float = 0.5, # lambda_mult: float = 0.5,
**kwargs: Any, # **kwargs: Any,
) -> List[Document]: # ) -> List[Document]:
raise NotImplementedError # raise NotImplementedError
# optional: add custom async implementations # optional: add custom async implementations
# async def amax_marginal_relevance_search_by_vector( # async def amax_marginal_relevance_search_by_vector(
@ -285,29 +437,3 @@ class __ModuleName__VectorStore(VectorStore):
# **kwargs: Any, # **kwargs: Any,
# ) -> List[Document]: # ) -> List[Document]:
# raise NotImplementedError # raise NotImplementedError
@classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
raise NotImplementedError
# optional: add custom async implementations
# @classmethod
# async def afrom_texts(
# cls: Type[VST],
# texts: List[str],
# embedding: Embeddings,
# metadatas: Optional[List[dict]] = None,
# **kwargs: Any,
# ) -> VST:
# return await asyncio.get_running_loop().run_in_executor(
# None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
# )
def _select_relevance_score_fn(self) -> Callable[[float], float]:
raise NotImplementedError

View File

@ -13,7 +13,7 @@ class Test__ModuleName__VectorStoreSync(ReadWriteTestSuite):
@pytest.fixture() @pytest.fixture()
def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore
"""Get an empty vectorstore for unit tests.""" """Get an empty vectorstore for unit tests."""
store = __ModuleName__VectorStore() store = __ModuleName__VectorStore(self.get_embeddings())
# note: store should be EMPTY at this point # note: store should be EMPTY at this point
# if you need to delete data, you may do so here # if you need to delete data, you may do so here
try: try:
@ -27,7 +27,7 @@ class Test__ModuleName__VectorStoreAsync(AsyncReadWriteTestSuite):
@pytest.fixture() @pytest.fixture()
async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignore async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignore
"""Get an empty vectorstore for unit tests.""" """Get an empty vectorstore for unit tests."""
store = __ModuleName__VectorStore() store = __ModuleName__VectorStore(self.get_embeddings())
# note: store should be EMPTY at this point # note: store should be EMPTY at this point
# if you need to delete data, you may do so here # if you need to delete data, you may do so here
try: try:

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
@ -996,4 +996,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<4.0" python-versions = ">=3.9,<4.0"
content-hash = "430e4bd69a7e58e26703886697a3c1d9965a29caa2afc6eb7d8efc78164f2730" content-hash = "a7c5efccaeae83ff262999a4b17048a6a3b29b29ee4e16ec9d6ab14b4cf4d21b"

View File

@ -26,6 +26,14 @@ httpx = ">=0.25.0,<1"
syrupy = "^4" syrupy = "^4"
pytest-socket = ">=0.6.0,<1" pytest-socket = ">=0.6.0,<1"
[[tool.poetry.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
[[tool.poetry.dependencies.numpy]]
version = ">=1.26.2,<3"
python = ">=3.12"
[tool.ruff.lint] [tool.ruff.lint]
select = ["E", "F", "I", "T201"] select = ["E", "F", "I", "T201"]
@ -55,14 +63,6 @@ optional = true
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.26.0"
python = ">=3.12"
[tool.poetry.group.test_integration.dependencies] [tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.codespell.dependencies] [tool.poetry.group.codespell.dependencies]