mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Add maximal relevance search to SKLearnVectorStore (#5430)
# Add maximal relevance search to SKLearnVectorStore This PR implements the maximum relevance search in SKLearnVectorStore. Twitter handle: jtolgyesi (I submitted also the original implementation of SKLearnVectorStore) ## Before submitting Unit tests are included. Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
8181f9e362
commit
1111f18eb4
@ -14,6 +14,10 @@ from uuid import uuid4
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
DEFAULT_K = 4 # Number of Documents to return.
|
||||
DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
|
||||
|
||||
|
||||
def guard_import(
|
||||
@ -223,39 +227,127 @@ class SKLearnVectorStore(VectorStore):
|
||||
self._neighbors.fit(self._embeddings_np)
|
||||
self._neighbors_fitted = True
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, *, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
def _similarity_index_search_with_score(
|
||||
self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""Search k embeddings similar to the query embedding. Returns a list of
|
||||
(index, distance) tuples."""
|
||||
if not self._neighbors_fitted:
|
||||
raise SKLearnVectorStoreException(
|
||||
"No data was added to SKLearnVectorStore."
|
||||
)
|
||||
query_embedding = self._embedding_function.embed_query(query)
|
||||
neigh_dists, neigh_idxs = self._neighbors.kneighbors(
|
||||
[query_embedding], n_neighbors=k
|
||||
)
|
||||
res = []
|
||||
for idx, dist in zip(neigh_idxs[0], neigh_dists[0]):
|
||||
_idx = int(idx)
|
||||
metadata = {"id": self._ids[_idx], **self._metadatas[_idx]}
|
||||
doc = Document(page_content=self._texts[_idx], metadata=metadata)
|
||||
res.append((doc, dist))
|
||||
return res
|
||||
return list(zip(neigh_idxs[0], neigh_dists[0]))
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
query_embedding = self._embedding_function.embed_query(query)
|
||||
indices_dists = self._similarity_index_search_with_score(
|
||||
query_embedding, k=k, **kwargs
|
||||
)
|
||||
return [
|
||||
(
|
||||
Document(
|
||||
page_content=self._texts[idx],
|
||||
metadata={"id": self._ids[idx], **self._metadatas[idx]},
|
||||
),
|
||||
dist,
|
||||
)
|
||||
for idx, dist in indices_dists
|
||||
]
|
||||
|
||||
def similarity_search(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
self, query: str, k: int = DEFAULT_K, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
|
||||
return [doc for doc, _ in docs_scores]
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
self, query: str, k: int = DEFAULT_K, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
docs_dists = self.similarity_search_with_score(query=query, k=k, **kwargs)
|
||||
docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
|
||||
docs, dists = zip(*docs_dists)
|
||||
scores = [1 / math.exp(dist) for dist in dists]
|
||||
return list(zip(list(docs), scores))
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = DEFAULT_K,
|
||||
fetch_k: int = DEFAULT_FETCH_K,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
indices_dists = self._similarity_index_search_with_score(
|
||||
embedding, k=fetch_k, **kwargs
|
||||
)
|
||||
indices, _ = zip(*indices_dists)
|
||||
result_embeddings = self._embeddings_np[indices,]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
self._np.array(embedding, dtype=self._np.float32),
|
||||
result_embeddings,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
mmr_indices = [indices[i] for i in mmr_selected]
|
||||
return [
|
||||
Document(
|
||||
page_content=self._texts[idx],
|
||||
metadata={"id": self._ids[idx], **self._metadatas[idx]},
|
||||
)
|
||||
for idx in mmr_indices
|
||||
]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = DEFAULT_K,
|
||||
fetch_k: int = DEFAULT_FETCH_K,
|
||||
lambda_mult: float = 0.5,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
if self._embedding_function is None:
|
||||
raise ValueError(
|
||||
"For MMR search, you must specify an embedding function on creation."
|
||||
)
|
||||
|
||||
embedding = self._embedding_function.embed_query(query)
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
embedding, k, fetch_k, lambda_mul=lambda_mult
|
||||
)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
|
@ -11,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
def test_sklearn() -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = SKLearnVectorStore.from_texts(texts, embedding=FakeEmbeddings())
|
||||
docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert len(output) == 1
|
||||
assert output[0].page_content == "foo"
|
||||
@ -24,7 +24,7 @@ def test_sklearn_with_metadatas() -> None:
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = SKLearnVectorStore.from_texts(
|
||||
texts,
|
||||
embedding=FakeEmbeddings(),
|
||||
FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
@ -38,7 +38,7 @@ def test_sklearn_with_metadatas_with_scores() -> None:
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = SKLearnVectorStore.from_texts(
|
||||
texts,
|
||||
embedding=FakeEmbeddings(),
|
||||
FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
)
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
|
||||
@ -69,8 +69,32 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
|
||||
|
||||
# Get a new VectorStore from the persisted directory
|
||||
docsearch = SKLearnVectorStore(
|
||||
embedding=FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
|
||||
FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert len(output) == 1
|
||||
assert output[0].page_content == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("numpy", "sklearn")
|
||||
def test_sklearn_mmr() -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
|
||||
assert len(output) == 1
|
||||
assert output[0].page_content == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("numpy", "sklearn")
|
||||
def test_sklearn_mmr_by_vector() -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
embeddings = FakeEmbeddings()
|
||||
docsearch = SKLearnVectorStore.from_texts(texts, embeddings)
|
||||
embedded_query = embeddings.embed_query("foo")
|
||||
output = docsearch.max_marginal_relevance_search_by_vector(
|
||||
embedded_query, k=1, fetch_k=3
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0].page_content == "foo"
|
||||
|
Loading…
Reference in New Issue
Block a user