mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
Add maximal relevance search to SKLearnVectorStore (#5430)
# Add maximal relevance search to SKLearnVectorStore This PR implements maximal marginal relevance (MMR) search in SKLearnVectorStore. Twitter handle: jtolgyesi (I also submitted the original implementation of SKLearnVectorStore.) ## Before submitting Unit tests are included. Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
8181f9e362
commit
1111f18eb4
@ -14,6 +14,10 @@ from uuid import uuid4
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
DEFAULT_K = 4 # Number of Documents to return.
|
||||||
|
DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
|
||||||
|
|
||||||
|
|
||||||
def guard_import(
|
def guard_import(
|
||||||
@ -223,39 +227,127 @@ class SKLearnVectorStore(VectorStore):
|
|||||||
self._neighbors.fit(self._embeddings_np)
|
self._neighbors.fit(self._embeddings_np)
|
||||||
self._neighbors_fitted = True
|
self._neighbors_fitted = True
|
||||||
|
|
||||||
def similarity_search_with_score(
|
def _similarity_index_search_with_score(
|
||||||
self, query: str, *, k: int = 4, **kwargs: Any
|
self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[int, float]]:
|
||||||
|
"""Search k embeddings similar to the query embedding. Returns a list of
|
||||||
|
(index, distance) tuples."""
|
||||||
if not self._neighbors_fitted:
|
if not self._neighbors_fitted:
|
||||||
raise SKLearnVectorStoreException(
|
raise SKLearnVectorStoreException(
|
||||||
"No data was added to SKLearnVectorStore."
|
"No data was added to SKLearnVectorStore."
|
||||||
)
|
)
|
||||||
query_embedding = self._embedding_function.embed_query(query)
|
|
||||||
neigh_dists, neigh_idxs = self._neighbors.kneighbors(
|
neigh_dists, neigh_idxs = self._neighbors.kneighbors(
|
||||||
[query_embedding], n_neighbors=k
|
[query_embedding], n_neighbors=k
|
||||||
)
|
)
|
||||||
res = []
|
return list(zip(neigh_idxs[0], neigh_dists[0]))
|
||||||
for idx, dist in zip(neigh_idxs[0], neigh_dists[0]):
|
|
||||||
_idx = int(idx)
|
def similarity_search_with_score(
|
||||||
metadata = {"id": self._ids[_idx], **self._metadatas[_idx]}
|
self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
|
||||||
doc = Document(page_content=self._texts[_idx], metadata=metadata)
|
) -> List[Tuple[Document, float]]:
|
||||||
res.append((doc, dist))
|
query_embedding = self._embedding_function.embed_query(query)
|
||||||
return res
|
indices_dists = self._similarity_index_search_with_score(
|
||||||
|
query_embedding, k=k, **kwargs
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
Document(
|
||||||
|
page_content=self._texts[idx],
|
||||||
|
metadata={"id": self._ids[idx], **self._metadatas[idx]},
|
||||||
|
),
|
||||||
|
dist,
|
||||||
|
)
|
||||||
|
for idx, dist in indices_dists
|
||||||
|
]
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
self, query: str, k: int = DEFAULT_K, **kwargs: Any
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
|
docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
|
||||||
return [doc for doc, _ in docs_scores]
|
return [doc for doc, _ in docs_scores]
|
||||||
|
|
||||||
def _similarity_search_with_relevance_scores(
|
def _similarity_search_with_relevance_scores(
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
self, query: str, k: int = DEFAULT_K, **kwargs: Any
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
docs_dists = self.similarity_search_with_score(query=query, k=k, **kwargs)
|
docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
|
||||||
docs, dists = zip(*docs_dists)
|
docs, dists = zip(*docs_dists)
|
||||||
scores = [1 / math.exp(dist) for dist in dists]
|
scores = [1 / math.exp(dist) for dist in dists]
|
||||||
return list(zip(list(docs), scores))
|
return list(zip(list(docs), scores))
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = DEFAULT_K,
    fetch_k: int = DEFAULT_FETCH_K,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Select documents for a query embedding via maximal marginal relevance.

    The ``fetch_k`` nearest neighbours of ``embedding`` are retrieved first;
    MMR then picks up to ``k`` of them, trading off similarity to the query
    against diversity among the documents already picked.

    Args:
        embedding: Query embedding to look up similar documents for.
        k: Number of Documents to return. Defaults to DEFAULT_K.
        fetch_k: Number of candidates handed to the MMR algorithm.
            Defaults to DEFAULT_FETCH_K.
        lambda_mult: Number between 0 and 1 controlling diversity, where 0
            means maximum diversity and 1 means minimum diversity.
            Defaults to 0.5.
        **kwargs: Forwarded unchanged to the nearest-neighbour index search.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    candidates = self._similarity_index_search_with_score(
        embedding, k=fetch_k, **kwargs
    )
    # Only the store indices are needed here; the distances are discarded.
    candidate_ids, _ = zip(*candidates)
    candidate_vectors = self._embeddings_np[candidate_ids,]
    query_vector = self._np.array(embedding, dtype=self._np.float32)

    picked_positions = maximal_marginal_relevance(
        query_vector,
        candidate_vectors,
        k=k,
        lambda_mult=lambda_mult,
    )

    selected: List[Document] = []
    for position in picked_positions:
        # MMR returns positions within the candidate list; map them back to
        # indices into the store's own arrays.
        store_idx = candidate_ids[position]
        selected.append(
            Document(
                page_content=self._texts[store_idx],
                metadata={"id": self._ids[store_idx], **self._metadatas[store_idx]},
            )
        )
    return selected
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = DEFAULT_K,
    fetch_k: int = DEFAULT_FETCH_K,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents. The query text is embedded with the store's
    embedding function and the search is delegated to
    ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to DEFAULT_K.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to DEFAULT_FETCH_K.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the store was created without an embedding function.
    """
    if self._embedding_function is None:
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )

    embedding = self._embedding_function.embed_query(query)
    # The keyword must be spelled ``lambda_mult``: a misspelled name would be
    # swallowed by the callee's **kwargs and the default 0.5 used instead.
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
    )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls,
|
||||||
|
@ -11,7 +11,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
|||||||
def test_sklearn() -> None:
|
def test_sklearn() -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
docsearch = SKLearnVectorStore.from_texts(texts, embedding=FakeEmbeddings())
|
docsearch = SKLearnVectorStore.from_texts(texts, FakeEmbeddings())
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
assert len(output) == 1
|
assert len(output) == 1
|
||||||
assert output[0].page_content == "foo"
|
assert output[0].page_content == "foo"
|
||||||
@ -24,7 +24,7 @@ def test_sklearn_with_metadatas() -> None:
|
|||||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||||
docsearch = SKLearnVectorStore.from_texts(
|
docsearch = SKLearnVectorStore.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding=FakeEmbeddings(),
|
FakeEmbeddings(),
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
@ -38,7 +38,7 @@ def test_sklearn_with_metadatas_with_scores() -> None:
|
|||||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||||
docsearch = SKLearnVectorStore.from_texts(
|
docsearch = SKLearnVectorStore.from_texts(
|
||||||
texts,
|
texts,
|
||||||
embedding=FakeEmbeddings(),
|
FakeEmbeddings(),
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
|
output = docsearch.similarity_search_with_relevance_scores("foo", k=1)
|
||||||
@ -69,8 +69,32 @@ def test_sklearn_with_persistence(tmpdir: Path) -> None:
|
|||||||
|
|
||||||
# Get a new VectorStore from the persisted directory
|
# Get a new VectorStore from the persisted directory
|
||||||
docsearch = SKLearnVectorStore(
|
docsearch = SKLearnVectorStore(
|
||||||
embedding=FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
|
FakeEmbeddings(), persist_path=str(persist_path), serializer="json"
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
assert len(output) == 1
|
assert len(output) == 1
|
||||||
assert output[0].page_content == "foo"
|
assert output[0].page_content == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_mmr() -> None:
    """MMR search by query text returns the matching text when k=1."""
    store = SKLearnVectorStore.from_texts(["foo", "bar", "baz"], FakeEmbeddings())
    results = store.max_marginal_relevance_search("foo", k=1, fetch_k=3)
    assert len(results) == 1
    assert results[0].page_content == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("numpy", "sklearn")
def test_sklearn_mmr_by_vector() -> None:
    """MMR search by a pre-computed query embedding returns the matching text."""
    embedder = FakeEmbeddings()
    store = SKLearnVectorStore.from_texts(["foo", "bar", "baz"], embedder)
    query_vector = embedder.embed_query("foo")
    results = store.max_marginal_relevance_search_by_vector(
        query_vector, k=1, fetch_k=3
    )
    assert len(results) == 1
    assert results[0].page_content == "foo"
|
||||||
|
Loading…
Reference in New Issue
Block a user