mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
community: Cassandra Vector Store: modernize implementation (#27253)
**Description:** This PR updates `CassandraGraphVectorStore` to be based on `CassandraVectorStore`, instead of using a custom CQL implementation. This allows users of a `CassandraVectorStore` to upgrade to a `GraphVectorStore` without having to change their database schema or re-embed documents. This PR also updates the documentation of the `GraphVectorStore` base class and contains native async implementations for the standard graph methods: `traversal_search` and `mmr_traversal_search` in `CassandraVectorStore`. **Issue:** No issue number. **Dependencies:** https://github.com/langchain-ai/langchain/pull/27078 (already merged) **Lint and test**: - Lint and tests all pass, including existing `CassandraGraphVectorStore` tests. - Also added numerous additional tests based on the tests in `langchain-astradb`, which cover many more scenarios than the existing tests for `Cassandra` and `CassandraGraphVectorStore`. **BREAKING CHANGE** Note that this is a breaking change for existing users of `CassandraGraphVectorStore`. They will need to wipe their database table and restart. However: - The interfaces have not changed, only the underlying storage mechanism. - Anyone using `langchain_community.vectorstores.Cassandra` can instead use `langchain_community.graph_vectorstores.CassandraGraphVectorStore` and they will gain Graph capabilities without having to re-embed their existing documents. This is the primary goal of this PR. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -144,6 +144,7 @@ from langchain_community.graph_vectorstores.cassandra import CassandraGraphVecto
|
||||
from langchain_community.graph_vectorstores.links import (
|
||||
Link,
|
||||
)
|
||||
from langchain_community.graph_vectorstores.mmr_helper import MmrHelper
|
||||
|
||||
__all__ = [
|
||||
"GraphVectorStore",
|
||||
@@ -151,4 +152,5 @@ __all__ = [
|
||||
"Node",
|
||||
"Link",
|
||||
"CassandraGraphVectorStore",
|
||||
"MmrHelper",
|
||||
]
|
||||
|
@@ -1,11 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterable, Collection, Iterable, Iterator
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
from langchain_core._api import beta
|
||||
@@ -21,6 +23,8 @@ from pydantic import Field
|
||||
|
||||
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _has_next(iterator: Iterator) -> bool:
|
||||
"""Checks if the iterator has more elements.
|
||||
@@ -158,6 +162,7 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
Args:
|
||||
nodes: the nodes to add.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
async def aadd_nodes(
|
||||
@@ -169,6 +174,7 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
Args:
|
||||
nodes: the nodes to add.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs))
|
||||
done = object()
|
||||
@@ -186,7 +192,7 @@ class GraphVectorStore(VectorStore):
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
"""Run more texts through the embeddings and add to the vector store.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
the `Node` links.
|
||||
@@ -214,15 +220,15 @@ class GraphVectorStore(VectorStore):
|
||||
)
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
texts: Iterable of strings to add to the vector store.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
The metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
ids: Optional list of IDs associated with the texts.
|
||||
**kwargs: vectorstore specific parameters.
|
||||
**kwargs: vector store specific parameters.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
List of ids from adding the texts into the vector store.
|
||||
"""
|
||||
nodes = _texts_to_nodes(texts, metadatas, ids)
|
||||
return list(self.add_nodes(nodes, **kwargs))
|
||||
@@ -235,7 +241,7 @@ class GraphVectorStore(VectorStore):
|
||||
ids: Optional[Iterable[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
"""Run more texts through the embeddings and add to the vector store.
|
||||
|
||||
The Links present in the metadata field `links` will be extracted to create
|
||||
the `Node` links.
|
||||
@@ -263,15 +269,15 @@ class GraphVectorStore(VectorStore):
|
||||
)
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
texts: Iterable of strings to add to the vector store.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
The metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
ids: Optional list of IDs associated with the texts.
|
||||
**kwargs: vectorstore specific parameters.
|
||||
**kwargs: vector store specific parameters.
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
List of ids from adding the texts into the vector store.
|
||||
"""
|
||||
nodes = _texts_to_nodes(texts, metadatas, ids)
|
||||
return [_id async for _id in self.aadd_nodes(nodes, **kwargs)]
|
||||
@@ -281,7 +287,7 @@ class GraphVectorStore(VectorStore):
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
"""Run more documents through the embeddings and add to the vector store.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
create the `Node` links.
|
||||
@@ -316,7 +322,7 @@ class GraphVectorStore(VectorStore):
|
||||
)
|
||||
|
||||
Args:
|
||||
documents: Documents to add to the vectorstore.
|
||||
documents: Documents to add to the vector store.
|
||||
The document's metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
|
||||
@@ -331,7 +337,7 @@ class GraphVectorStore(VectorStore):
|
||||
documents: Iterable[Document],
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more documents through the embeddings and add to the vectorstore.
|
||||
"""Run more documents through the embeddings and add to the vector store.
|
||||
|
||||
The Links present in the document metadata field `links` will be extracted to
|
||||
create the `Node` links.
|
||||
@@ -366,7 +372,7 @@ class GraphVectorStore(VectorStore):
|
||||
)
|
||||
|
||||
Args:
|
||||
documents: Documents to add to the vectorstore.
|
||||
documents: Documents to add to the vector store.
|
||||
The document's metadata key `links` shall be an iterable of
|
||||
:py:class:`~langchain_community.graph_vectorstores.links.Link`.
|
||||
|
||||
@@ -383,6 +389,7 @@ class GraphVectorStore(VectorStore):
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 1,
|
||||
filter: dict[str, Any] | None = None, # noqa: A002
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Document]:
|
||||
"""Retrieve documents from traversing this graph store.
|
||||
@@ -396,8 +403,10 @@ class GraphVectorStore(VectorStore):
|
||||
k: The number of Documents to return from the initial search.
|
||||
Defaults to 4. Applies to each of the query strings.
|
||||
depth: The maximum depth of edges to traverse. Defaults to 1.
|
||||
filter: Optional metadata to filter the results.
|
||||
**kwargs: Additional keyword arguments.
|
||||
Returns:
|
||||
Retrieved documents.
|
||||
Collection of retrieved documents.
|
||||
"""
|
||||
|
||||
async def atraversal_search(
|
||||
@@ -406,6 +415,7 @@ class GraphVectorStore(VectorStore):
|
||||
*,
|
||||
k: int = 4,
|
||||
depth: int = 1,
|
||||
filter: dict[str, Any] | None = None, # noqa: A002
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[Document]:
|
||||
"""Retrieve documents from traversing this graph store.
|
||||
@@ -419,12 +429,20 @@ class GraphVectorStore(VectorStore):
|
||||
k: The number of Documents to return from the initial search.
|
||||
Defaults to 4. Applies to each of the query strings.
|
||||
depth: The maximum depth of edges to traverse. Defaults to 1.
|
||||
filter: Optional metadata to filter the results.
|
||||
**kwargs: Additional keyword arguments.
|
||||
Returns:
|
||||
Retrieved documents.
|
||||
Collection of retrieved documents.
|
||||
"""
|
||||
iterator = iter(
|
||||
await run_in_executor(
|
||||
None, self.traversal_search, query, k=k, depth=depth, **kwargs
|
||||
None,
|
||||
self.traversal_search,
|
||||
query,
|
||||
k=k,
|
||||
depth=depth,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
done = object()
|
||||
@@ -439,12 +457,14 @@ class GraphVectorStore(VectorStore):
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
initial_roots: Sequence[str] = (),
|
||||
k: int = 4,
|
||||
depth: int = 2,
|
||||
fetch_k: int = 100,
|
||||
adjacent_k: int = 10,
|
||||
lambda_mult: float = 0.5,
|
||||
score_threshold: float = float("-inf"),
|
||||
filter: dict[str, Any] | None = None, # noqa: A002
|
||||
**kwargs: Any,
|
||||
) -> Iterable[Document]:
|
||||
"""Retrieve documents from this graph store using MMR-traversal.
|
||||
@@ -459,6 +479,10 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
Args:
|
||||
query: The query string to search for.
|
||||
initial_roots: Optional list of document IDs to use for initializing search.
|
||||
The top `adjacent_k` nodes adjacent to each initial root will be
|
||||
included in the set of initial candidates. To fetch only in the
|
||||
neighborhood of these nodes, set `fetch_k = 0`.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch via similarity.
|
||||
Defaults to 100.
|
||||
@@ -471,18 +495,22 @@ class GraphVectorStore(VectorStore):
|
||||
diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
score_threshold: Only documents with a score greater than or equal
|
||||
this threshold will be chosen. Defaults to negative infinity.
|
||||
filter: Optional metadata to filter the results.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
|
||||
async def ammr_traversal_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
initial_roots: Sequence[str] = (),
|
||||
k: int = 4,
|
||||
depth: int = 2,
|
||||
fetch_k: int = 100,
|
||||
adjacent_k: int = 10,
|
||||
lambda_mult: float = 0.5,
|
||||
score_threshold: float = float("-inf"),
|
||||
filter: dict[str, Any] | None = None, # noqa: A002
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[Document]:
|
||||
"""Retrieve documents from this graph store using MMR-traversal.
|
||||
@@ -497,6 +525,10 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
Args:
|
||||
query: The query string to search for.
|
||||
initial_roots: Optional list of document IDs to use for initializing search.
|
||||
The top `adjacent_k` nodes adjacent to each initial root will be
|
||||
included in the set of initial candidates. To fetch only in the
|
||||
neighborhood of these nodes, set `fetch_k = 0`.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch via similarity.
|
||||
Defaults to 100.
|
||||
@@ -509,18 +541,22 @@ class GraphVectorStore(VectorStore):
|
||||
diversity and 1 to minimum diversity. Defaults to 0.5.
|
||||
score_threshold: Only documents with a score greater than or equal
|
||||
this threshold will be chosen. Defaults to negative infinity.
|
||||
filter: Optional metadata to filter the results.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
iterator = iter(
|
||||
await run_in_executor(
|
||||
None,
|
||||
self.mmr_traversal_search,
|
||||
query,
|
||||
initial_roots=initial_roots,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
adjacent_k=adjacent_k,
|
||||
depth=depth,
|
||||
lambda_mult=lambda_mult,
|
||||
score_threshold=score_threshold,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@@ -544,6 +580,11 @@ class GraphVectorStore(VectorStore):
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
if kwargs.get("depth", 0) > 0:
|
||||
logger.warning(
|
||||
"'mmr' search started with depth > 0. "
|
||||
"Maybe you meant to do a 'mmr_traversal' search?"
|
||||
)
|
||||
return list(
|
||||
self.mmr_traversal_search(
|
||||
query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0
|
||||
@@ -573,7 +614,7 @@ class GraphVectorStore(VectorStore):
|
||||
raise ValueError(
|
||||
f"search_type of {search_type} not allowed. Expected "
|
||||
"search_type to be 'similarity', 'similarity_score_threshold', "
|
||||
"'mmr' or 'traversal'."
|
||||
"'mmr', 'traversal', or 'mmr_traversal'."
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
@@ -590,11 +631,13 @@ class GraphVectorStore(VectorStore):
|
||||
return await self.amax_marginal_relevance_search(query, **kwargs)
|
||||
elif search_type == "traversal":
|
||||
return [doc async for doc in self.atraversal_search(query, **kwargs)]
|
||||
elif search_type == "mmr_traversal":
|
||||
return [doc async for doc in self.ammr_traversal_search(query, **kwargs)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"search_type of {search_type} not allowed. Expected "
|
||||
"search_type to be 'similarity', 'similarity_score_threshold', "
|
||||
"'mmr' or 'traversal'."
|
||||
"'mmr', 'traversal', or 'mmr_traversal'."
|
||||
)
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever:
|
||||
@@ -606,13 +649,14 @@ class GraphVectorStore(VectorStore):
|
||||
|
||||
- search_type (Optional[str]): Defines the type of search that
|
||||
the Retriever should perform.
|
||||
Can be ``traversal`` (default), ``similarity``, ``mmr``, or
|
||||
``similarity_score_threshold``.
|
||||
Can be ``traversal`` (default), ``similarity``, ``mmr``,
|
||||
``mmr_traversal``, or ``similarity_score_threshold``.
|
||||
- search_kwargs (Optional[Dict]): Keyword arguments to pass to the
|
||||
search function. Can include things like:
|
||||
|
||||
- k(int): Amount of documents to return (Default: 4).
|
||||
- depth(int): The maximum depth of edges to traverse (Default: 1).
|
||||
Only applies to search_type: ``traversal`` and ``mmr_traversal``.
|
||||
- score_threshold(float): Minimum relevance threshold
|
||||
for similarity_score_threshold.
|
||||
- fetch_k(int): Amount of documents to pass to MMR algorithm
|
||||
@@ -629,21 +673,21 @@ class GraphVectorStore(VectorStore):
|
||||
# Retrieve documents traversing edges
|
||||
docsearch.as_retriever(
|
||||
search_type="traversal",
|
||||
search_kwargs={'k': 6, 'depth': 3}
|
||||
search_kwargs={'k': 6, 'depth': 2}
|
||||
)
|
||||
|
||||
# Retrieve more documents with higher diversity
|
||||
# Retrieve documents with higher diversity
|
||||
# Useful if your dataset has many similar documents
|
||||
docsearch.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={'k': 6, 'lambda_mult': 0.25}
|
||||
search_type="mmr_traversal",
|
||||
search_kwargs={'k': 6, 'lambda_mult': 0.25, 'depth': 2}
|
||||
)
|
||||
|
||||
# Fetch more documents for the MMR algorithm to consider
|
||||
# But only return the top 5
|
||||
docsearch.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={'k': 5, 'fetch_k': 50}
|
||||
search_type="mmr_traversal",
|
||||
search_kwargs={'k': 5, 'fetch_k': 50, 'depth': 2}
|
||||
)
|
||||
|
||||
# Only retrieve documents that have a relevance score
|
||||
@@ -657,7 +701,7 @@ class GraphVectorStore(VectorStore):
|
||||
docsearch.as_retriever(search_kwargs={'k': 1})
|
||||
|
||||
"""
|
||||
return GraphVectorStoreRetriever(vectorstore=self, **kwargs)
|
||||
return GraphVectorStoreRetriever(vector_store=self, **kwargs)
|
||||
|
||||
|
||||
@beta(message="Added in version 0.3.1 of langchain_community. API subject to change.")
|
||||
@@ -744,7 +788,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
Passing search parameters
|
||||
-------------------------
|
||||
|
||||
We can pass parameters to the underlying graph vectorstore's search methods using
|
||||
We can pass parameters to the underlying graph vector store's search methods using
|
||||
``search_kwargs``.
|
||||
|
||||
Specifying graph traversal depth
|
||||
@@ -793,7 +837,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5})
|
||||
""" # noqa: E501
|
||||
|
||||
vectorstore: GraphVectorStore
|
||||
vector_store: GraphVectorStore
|
||||
"""GraphVectorStore to use for retrieval."""
|
||||
search_type: str = "traversal"
|
||||
"""Type of search to perform. Defaults to "traversal"."""
|
||||
@@ -809,10 +853,10 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> list[Document]:
|
||||
if self.search_type == "traversal":
|
||||
return list(self.vectorstore.traversal_search(query, **self.search_kwargs))
|
||||
return list(self.vector_store.traversal_search(query, **self.search_kwargs))
|
||||
elif self.search_type == "mmr_traversal":
|
||||
return list(
|
||||
self.vectorstore.mmr_traversal_search(query, **self.search_kwargs)
|
||||
self.vector_store.mmr_traversal_search(query, **self.search_kwargs)
|
||||
)
|
||||
else:
|
||||
return super()._get_relevant_documents(query, run_manager=run_manager)
|
||||
@@ -823,14 +867,14 @@ class GraphVectorStoreRetriever(VectorStoreRetriever):
|
||||
if self.search_type == "traversal":
|
||||
return [
|
||||
doc
|
||||
async for doc in self.vectorstore.atraversal_search(
|
||||
async for doc in self.vector_store.atraversal_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
]
|
||||
elif self.search_type == "mmr_traversal":
|
||||
return [
|
||||
doc
|
||||
async for doc in self.vectorstore.ammr_traversal_search(
|
||||
async for doc in self.vector_store.ammr_traversal_search(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
]
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,272 @@
|
||||
"""Tools for the Graph Traversal Maximal Marginal Relevance (MMR) reranking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain_community.utils.math import cosine_similarity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
    """Convert an embedding into a ``float32`` ndarray with a row axis.

    Args:
        embedding: The embedding values. A 1-D input is treated as a single
            vector; input with any other number of dimensions is only cast.

    Returns:
        A ``float32`` ndarray. A 1-D input of length ``dim`` is returned with
        shape ``(1, dim)``; any other input keeps its shape.
    """
    emb_array = np.array(embedding, dtype=np.float32)
    if emb_array.ndim == 1:
        # Promote a single vector to a one-row matrix.
        emb_array = np.expand_dims(emb_array, axis=0)
    return emb_array
|
||||
|
||||
|
||||
# Lowest possible score. Used as the default `score_threshold` and as the
# value `best_score` is reset to before candidates are (re)scored.
NEG_INF = float("-inf")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _Candidate:
    """Scoring record for a single MMR candidate.

    Attributes:
        id: ID of the candidate.
        similarity: Similarity of the candidate to the query.
        weighted_similarity: ``similarity`` after weighting; the positive term
            of the score.
        weighted_redundancy: Largest weighted redundancy seen so far; the
            negative term of the score.
        score: ``weighted_similarity - weighted_redundancy``. Derived, so it is
            not an ``__init__`` argument.
    """

    id: str
    similarity: float
    weighted_similarity: float
    weighted_redundancy: float
    score: float = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        """Compute the initial score from the two weighted terms."""
        self.score = self.weighted_similarity - self.weighted_redundancy

    def update_redundancy(self, new_weighted_redundancy: float) -> None:
        """Raise the redundancy to ``new_weighted_redundancy`` if it is larger.

        The stored redundancy is a running maximum, so a smaller value leaves
        both the redundancy and the score untouched.

        Args:
            new_weighted_redundancy: Weighted redundancy against a newly
                selected item.
        """
        if new_weighted_redundancy > self.weighted_redundancy:
            self.weighted_redundancy = new_weighted_redundancy
            self.score = self.weighted_similarity - self.weighted_redundancy
|
||||
|
||||
|
||||
class MmrHelper:
    """Helper for executing an MMR traversal query.

    Candidates are scored as
    ``lambda_mult * similarity_to_query
    - (1 - lambda_mult) * max_similarity_to_selected``.
    `add_candidates` grows the consideration set and `pop_best` moves the
    highest-scoring candidate into the selected set.

    Args:
        k: Number of selections to pre-allocate embedding storage for.
        query_embedding: The embedding of the query to use for scoring.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding to maximum
            diversity and 1 to minimum diversity. Defaults to 0.5.
        score_threshold: Only documents with a score greater than or equal
            to this threshold will be chosen. Defaults to -infinity.
    """

    dimensions: int
    """Dimensions of the embedding."""

    query_embedding: NDArray[np.float32]
    """Embedding of the query as a (1,dim) ndarray."""

    lambda_mult: float
    """Number between 0 and 1.

    Determines the degree of diversity among the results with 0 corresponding to
    maximum diversity and 1 to minimum diversity."""

    lambda_mult_complement: float
    """1 - lambda_mult."""

    score_threshold: float
    """Only documents with a score greater than or equal to this will be chosen."""

    selected_ids: list[str]
    """List of selected IDs (in selection order)."""

    selected_mmr_scores: list[float]
    """List of MMR score at the time each document is selected."""

    selected_similarity_scores: list[float]
    """List of similarity score for each selected document."""

    selected_embeddings: NDArray[np.float32]
    """(N, dim) ndarray with a row for each selected node.

    Only the first `len(selected_ids)` rows hold assigned values.
    """

    candidate_id_to_index: dict[str, int]
    """Dictionary of candidate IDs to indices in candidates and candidate_embeddings."""

    candidates: list[_Candidate]
    """List containing information about candidates.

    Same order as rows in `candidate_embeddings`.
    """

    candidate_embeddings: NDArray[np.float32]
    """(N, dim) ndarray with a row for each candidate."""

    best_score: float
    """Score of the best remaining candidate, or -infinity if there is none."""

    best_id: str | None
    """ID of the best remaining candidate, or None if there is none."""

    def __init__(
        self,
        k: int,
        query_embedding: list[float],
        lambda_mult: float = 0.5,
        score_threshold: float = NEG_INF,
    ) -> None:
        """Create a new Traversal MMR helper."""
        self.query_embedding = _emb_to_ndarray(query_embedding)
        self.dimensions = self.query_embedding.shape[1]

        self.lambda_mult = lambda_mult
        self.lambda_mult_complement = 1 - lambda_mult
        self.score_threshold = score_threshold

        self.selected_ids = []
        self.selected_similarity_scores = []
        self.selected_mmr_scores = []

        # Selected embeddings (in selection order). Rows are uninitialized
        # until `pop_best` assigns them.
        self.selected_embeddings = np.empty((k, self.dimensions), dtype=np.float32)

        self.candidate_id_to_index = {}

        # List of the candidates.
        self.candidates = []
        # numpy n-dimensional array of the candidate embeddings.
        self.candidate_embeddings = np.empty((0, self.dimensions), dtype=np.float32)

        self.best_score = NEG_INF
        self.best_id = None

    def candidate_ids(self) -> Iterable[str]:
        """Return the IDs of the candidates."""
        return self.candidate_id_to_index.keys()

    def _already_selected_embeddings(self) -> NDArray[np.float32]:
        """Return the selected embeddings sliced to the already assigned values."""
        return self.selected_embeddings[: len(self.selected_ids)]

    def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
        """Pop the candidate with the given ID.

        The removed slot is back-filled with the last candidate so that
        `candidates`, `candidate_embeddings` and `candidate_id_to_index` stay
        dense and aligned.

        Returns:
            The similarity score and embedding of the candidate.

        Raises:
            KeyError: If `candidate_id` is not a current candidate.
            ValueError: If the index bookkeeping is inconsistent.
        """
        # Get the embedding for the id.
        index = self.candidate_id_to_index.pop(candidate_id)
        if self.candidates[index].id != candidate_id:
            msg = (
                "ID in self.candidate_id_to_index doesn't match the ID of the "
                "corresponding index in self.candidates"
            )
            raise ValueError(msg)
        # Copy: the row is about to be overwritten by the swap below.
        embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()

        # Swap that index with the last index in the candidates and
        # candidate_embeddings.
        last_index = self.candidate_embeddings.shape[0] - 1

        if index == last_index:
            # Already the last item. We don't need to swap.
            similarity = self.candidates.pop().similarity
        else:
            self.candidate_embeddings[index] = self.candidate_embeddings[last_index]

            similarity = self.candidates[index].similarity

            old_last = self.candidates.pop()
            self.candidates[index] = old_last
            self.candidate_id_to_index[old_last.id] = index

        # Drop the (now duplicated or removed) last row.
        self.candidate_embeddings = self.candidate_embeddings[:last_index]

        return similarity, embedding

    def pop_best(self) -> str | None:
        """Select and pop the best item being considered.

        Updates the consideration set based on it.

        Returns:
            The ID of the best item, or None if there are no candidates or the
            best candidate scores below `score_threshold`.
        """
        if self.best_id is None or self.best_score < self.score_threshold:
            return None

        # Get the selection and remove from candidates.
        selected_id = self.best_id
        selected_similarity, selected_embedding = self._pop_candidate(selected_id)

        # Add the ID and embedding to the selected information.
        selection_index = len(self.selected_ids)
        capacity = self.selected_embeddings.shape[0]
        if selection_index >= capacity:
            # More selections than pre-allocated rows: grow the buffer instead
            # of failing with an IndexError half-way through the update.
            extra_rows = np.empty(
                (max(1, capacity), self.dimensions), dtype=np.float32
            )
            self.selected_embeddings = np.vstack(
                (self.selected_embeddings, extra_rows)
            )
        self.selected_ids.append(selected_id)
        self.selected_mmr_scores.append(self.best_score)
        self.selected_similarity_scores.append(selected_similarity)
        self.selected_embeddings[selection_index] = selected_embedding

        # Reset the best score / best ID.
        self.best_score = NEG_INF
        self.best_id = None

        # Update the candidates redundancy, tracking the best node.
        if self.candidate_embeddings.shape[0] > 0:
            similarity = cosine_similarity(
                self.candidate_embeddings, np.expand_dims(selected_embedding, axis=0)
            )
            for index, candidate in enumerate(self.candidates):
                # Weight the redundancy exactly as `add_candidates` does, so a
                # candidate's score does not depend on whether it was added
                # before or after this selection.
                candidate.update_redundancy(
                    self.lambda_mult_complement * similarity[index][0]
                )
                if candidate.score > self.best_score:
                    self.best_score = candidate.score
                    self.best_id = candidate.id

        return selected_id

    def add_candidates(self, candidates: dict[str, list[float]]) -> None:
        """Add candidates to the consideration set.

        IDs that are already selected or already under consideration are
        ignored. New candidates are processed in the iteration order of
        `candidates`.

        Args:
            candidates: Mapping of candidate ID to its embedding.
        """
        # Determine the keys to actually include.
        # These are the candidates that aren't already selected
        # or under consideration.
        already_selected = set(self.selected_ids)
        include_ids = [
            candidate_id
            for candidate_id in candidates
            if candidate_id not in already_selected
            and candidate_id not in self.candidate_id_to_index
        ]
        if not include_ids:
            return

        # Build up a matrix of the new candidate embeddings. Nothing is
        # registered until the matrix and similarities have been computed, so
        # a malformed embedding cannot leave the bookkeeping half-updated.
        new_embeddings: NDArray[np.float32] = np.empty(
            (len(include_ids), self.dimensions), dtype=np.float32
        )
        for index, candidate_id in enumerate(include_ids):
            new_embeddings[index] = candidates[candidate_id]

        # Compute the similarity to the query.
        similarity = cosine_similarity(new_embeddings, self.query_embedding)

        # Compute the similarity of every new candidate to every item in the
        # selected set. Skipped when nothing has been selected yet.
        selected_embeddings = self._already_selected_embeddings()
        redundancy = None
        if selected_embeddings.shape[0] > 0:
            redundancy = cosine_similarity(new_embeddings, selected_embeddings)

        offset = self.candidate_embeddings.shape[0]
        for index, candidate_id in enumerate(include_ids):
            max_redundancy = 0.0 if redundancy is None else redundancy[index].max()
            candidate = _Candidate(
                id=candidate_id,
                similarity=similarity[index][0],
                weighted_similarity=self.lambda_mult * similarity[index][0],
                weighted_redundancy=self.lambda_mult_complement * max_redundancy,
            )
            self.candidate_id_to_index[candidate_id] = offset + index
            self.candidates.append(candidate)

            if candidate.score >= self.best_score:
                self.best_score = candidate.score
                self.best_id = candidate.id

        # Add the new embeddings to the candidate set.
        self.candidate_embeddings = np.vstack(
            (
                self.candidate_embeddings,
                new_embeddings,
            )
        )
|
Reference in New Issue
Block a user