cli[patch]: implement minimal starter vector store (#28577)

This is essentially the same as LangChain core's in-memory vector store, with some of the optional methods removed.
This commit is contained in:
ccurme 2024-12-06 13:10:22 -05:00 committed by GitHub
parent 5277a021c1
commit f3dc142d3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 208 additions and 83 deletions

View File

@ -45,5 +45,4 @@ _e2e_test:
poetry run pip install -e ../../../standard-tests && \
make format lint tests && \
poetry install --with test_integration && \
rm tests/integration_tests/test_vectorstores.py && \
make integration_test

View File

@ -2,23 +2,23 @@
from __future__ import annotations
import uuid
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.vectorstores.utils import _cosine_similarity as cosine_similarity
VST = TypeVar("VST", bound=VectorStore)
@ -158,40 +158,184 @@ class __ModuleName__VectorStore(VectorStore):
""" # noqa: E501
_database: dict[str, tuple[Document, list[float]]] = {}
def __init__(self, embedding: Embeddings) -> None:
    """Initialize with the given embedding function.

    Args:
        embedding: embedding function to use.
    """
    # Maps document id -> {"id", "vector", "text", "metadata"}.
    self._database: dict[str, dict[str, Any]] = {}
    self.embedding = embedding

@classmethod
def from_texts(
    cls: Type[__ModuleName__VectorStore],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    **kwargs: Any,
) -> __ModuleName__VectorStore:
    """Create an empty store and add *texts* to it in one step.

    Args:
        texts: texts to embed and store.
        embedding: embedding function to use.
        metadatas: optional per-text metadata dicts, parallel to ``texts``.

    Returns:
        A populated vector store instance.
    """
    store = cls(
        embedding=embedding,
    )
    store.add_texts(texts=texts, metadatas=metadatas, **kwargs)
    return store
# optional: add custom async implementations
# async def aadd_texts(
# self,
# texts: Iterable[str],
# @classmethod
# async def afrom_texts(
# cls: Type[VST],
# texts: List[str],
# embedding: Embeddings,
# metadatas: Optional[List[dict]] = None,
# **kwargs: Any,
# ) -> List[str]:
# ) -> VST:
# return await asyncio.get_running_loop().run_in_executor(
# None, partial(self.add_texts, **kwargs), texts, metadatas
# None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
# )
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
raise NotImplementedError
@property
def embeddings(self) -> Embeddings:
    """Expose the embedding object passed at construction time."""
    return self.embedding
def add_documents(
    self,
    documents: List[Document],
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Embed and store *documents*, returning the ids they were stored under.

    Args:
        documents: documents to add.
        ids: optional explicit ids, parallel to ``documents``. When omitted,
            each document's own ``id`` is used, falling back to a fresh UUID.

    Raises:
        ValueError: if ``ids`` is given but its length differs from ``documents``.
    """
    texts = [document.page_content for document in documents]
    vectors = self.embedding.embed_documents(texts)

    if ids and len(ids) != len(texts):
        msg = (
            f"ids must be the same length as texts. "
            f"Got {len(ids)} ids and {len(texts)} texts."
        )
        raise ValueError(msg)

    # Prefer caller-supplied ids; otherwise fall back to per-document ids.
    provided_ids: Iterator[Optional[str]] = (
        iter(ids) if ids else (document.id for document in documents)
    )
    assigned_ids: List[str] = []
    for document, vector in zip(documents, vectors):
        candidate = next(provided_ids)
        final_id = candidate if candidate else str(uuid.uuid4())
        assigned_ids.append(final_id)
        self._database[final_id] = {
            "id": final_id,
            "vector": vector,
            "text": document.page_content,
            "metadata": document.metadata,
        }
    return assigned_ids
# optional: add custom async implementations
# async def aadd_documents(
# self,
# documents: List[Document],
# ids: Optional[List[str]] = None,
# **kwargs: Any,
# ) -> List[str]:
# raise NotImplementedError
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
    """Remove the entries with the given ids; unknown ids are ignored."""
    for doc_id in ids or []:
        self._database.pop(doc_id, None)
# optional: add custom async implementations
# async def adelete(
# self, ids: Optional[List[str]] = None, **kwargs: Any
# ) -> Optional[bool]:
# ) -> None:
# raise NotImplementedError
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
    """Get documents by their ids.

    Args:
        ids: The ids of the documents to get.

    Returns:
        A list of Document objects. Ids with no stored entry are skipped.
    """
    found: list[Document] = []
    for requested_id in ids:
        entry = self._database.get(requested_id)
        if not entry:
            continue
        found.append(
            Document(
                id=entry["id"],
                page_content=entry["text"],
                metadata=entry["metadata"],
            )
        )
    return found
# optional: add custom async implementations
# async def aget_by_ids(self, ids: Sequence[str], /) -> list[Document]:
# raise NotImplementedError
# NOTE: the below helper method implements similarity search for in-memory
# storage. It is optional and not a part of the vector store interface.
# NOTE: the below helper method implements similarity search for in-memory
# storage. It is optional and not a part of the vector store interface.
def _similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Callable[[Document], bool]] = None,
    **kwargs: Any,
) -> List[tuple[Document, float, List[float]]]:
    """Rank stored entries by cosine similarity to *embedding*.

    Args:
        embedding: query vector.
        k: maximum number of results.
        filter: optional predicate applied to each candidate Document.

    Returns:
        Up to ``k`` (Document, score, vector) triples, best match first.
    """
    # Snapshot the database values so index positions are stable.
    candidates = list(self._database.values())
    if filter is not None:
        candidates = [
            entry
            for entry in candidates
            if filter(Document(page_content=entry["text"], metadata=entry["metadata"]))
        ]
    if not candidates:
        return []
    scores = cosine_similarity([embedding], [entry["vector"] for entry in candidates])[0]
    # Indices of the k highest scores, in descending score order.
    ranked = scores.argsort()[::-1][:k]
    results: List[tuple[Document, float, List[float]]] = []
    for idx in ranked:
        entry = candidates[idx]
        if not entry:
            continue
        results.append(
            (
                Document(
                    id=entry["id"],
                    page_content=entry["text"],
                    metadata=entry["metadata"],
                ),
                float(scores[idx].item()),
                entry["vector"],
            )
        )
    return results
def similarity_search(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
    """Return the ``k`` documents most similar to *query*.

    Args:
        query: text to embed and search for.
        k: maximum number of documents to return.

    Returns:
        The best-matching documents, best first.
    """
    # Fix: the diff left the removed `raise NotImplementedError` line
    # interleaved here, which made the real body below unreachable.
    embedding = self.embedding.embed_query(query)
    return [
        doc
        for doc, _, _ in self._similarity_search_with_score_by_vector(
            embedding=embedding, k=k, **kwargs
        )
    ]
# optional: add custom async implementations
# async def asimilarity_search(
@ -204,9 +348,15 @@ class __ModuleName__VectorStore(VectorStore):
# return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_with_score(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
    """Return the ``k`` most similar documents with their similarity scores.

    Args:
        query: text to embed and search for.
        k: maximum number of results.

    Returns:
        (Document, score) pairs, best first.
    """
    # Fix: the diff left the removed signature and the removed
    # `raise NotImplementedError` interleaved with the new body.
    embedding = self.embedding.embed_query(query)
    return [
        (doc, similarity)
        for doc, similarity, _ in self._similarity_search_with_score_by_vector(
            embedding=embedding, k=k, **kwargs
        )
    ]
# optional: add custom async implementations
# async def asimilarity_search_with_score(
@ -218,10 +368,12 @@ class __ModuleName__VectorStore(VectorStore):
# func = partial(self.similarity_search_with_score, *args, **kwargs)
# return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
raise NotImplementedError
### ADDITIONAL OPTIONAL SEARCH METHODS BELOW ###
# def similarity_search_by_vector(
# self, embedding: List[float], k: int = 4, **kwargs: Any
# ) -> List[Document]:
# raise NotImplementedError
# optional: add custom async implementations
# async def asimilarity_search_by_vector(
@ -233,15 +385,15 @@ class __ModuleName__VectorStore(VectorStore):
# func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
# return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
# def max_marginal_relevance_search(
# self,
# query: str,
# k: int = 4,
# fetch_k: int = 20,
# lambda_mult: float = 0.5,
# **kwargs: Any,
# ) -> List[Document]:
# raise NotImplementedError
# optional: add custom async implementations
# async def amax_marginal_relevance_search(
@ -265,15 +417,15 @@ class __ModuleName__VectorStore(VectorStore):
# )
# return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
# def max_marginal_relevance_search_by_vector(
# self,
# embedding: List[float],
# k: int = 4,
# fetch_k: int = 20,
# lambda_mult: float = 0.5,
# **kwargs: Any,
# ) -> List[Document]:
# raise NotImplementedError
# optional: add custom async implementations
# async def amax_marginal_relevance_search_by_vector(
@ -285,29 +437,3 @@ class __ModuleName__VectorStore(VectorStore):
# **kwargs: Any,
# ) -> List[Document]:
# raise NotImplementedError
@classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
raise NotImplementedError
# optional: add custom async implementations
# @classmethod
# async def afrom_texts(
# cls: Type[VST],
# texts: List[str],
# embedding: Embeddings,
# metadatas: Optional[List[dict]] = None,
# **kwargs: Any,
# ) -> VST:
# return await asyncio.get_running_loop().run_in_executor(
# None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
# )
def _select_relevance_score_fn(self) -> Callable[[float], float]:
raise NotImplementedError

View File

@ -13,7 +13,7 @@ class Test__ModuleName__VectorStoreSync(ReadWriteTestSuite):
@pytest.fixture()
def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore
"""Get an empty vectorstore for unit tests."""
store = __ModuleName__VectorStore()
store = __ModuleName__VectorStore(self.get_embeddings())
# note: store should be EMPTY at this point
# if you need to delete data, you may do so here
try:
@ -27,7 +27,7 @@ class Test__ModuleName__VectorStoreAsync(AsyncReadWriteTestSuite):
@pytest.fixture()
async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignore
"""Get an empty vectorstore for unit tests."""
store = __ModuleName__VectorStore()
store = __ModuleName__VectorStore(self.get_embeddings())
# note: store should be EMPTY at this point
# if you need to delete data, you may do so here
try:

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -996,4 +996,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "430e4bd69a7e58e26703886697a3c1d9965a29caa2afc6eb7d8efc78164f2730"
content-hash = "a7c5efccaeae83ff262999a4b17048a6a3b29b29ee4e16ec9d6ab14b4cf4d21b"

View File

@ -26,6 +26,14 @@ httpx = ">=0.25.0,<1"
syrupy = "^4"
pytest-socket = ">=0.6.0,<1"
[[tool.poetry.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
[[tool.poetry.dependencies.numpy]]
version = ">=1.26.2,<3"
python = ">=3.12"
[tool.ruff.lint]
select = ["E", "F", "I", "T201"]
@ -55,14 +63,6 @@ optional = true
[tool.poetry.group.test.dependencies]
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.24.0"
python = "<3.12"
[[tool.poetry.group.test.dependencies.numpy]]
version = "^1.26.0"
python = ">=3.12"
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.codespell.dependencies]