mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 07:50:39 +00:00
Implemented MMR search for Redis (#10140)
Description: Implemented MMR search for Redis. The change is straightforward: it applies the already-implemented MMR method to the documents fetched by similarity search. Issue: #10059. Dependencies: None. Twitter handle: @hamza_tahboub. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
5d8a689d5e
commit
8c0f391815
@ -413,7 +413,8 @@
|
||||
"- ``similarity_search``: Find the most similar vectors to a given vector.\n",
|
||||
"- ``similarity_search_with_score``: Find the most similar vectors to a given vector and return the vector distance\n",
|
||||
"- ``similarity_search_limit_score``: Find the most similar vectors to a given vector and limit the number of results to the ``score_threshold``\n",
|
||||
"- ``similarity_search_with_relevance_scores``: Find the most similar vectors to a given vector and return the vector similarities"
|
||||
"- ``similarity_search_with_relevance_scores``: Find the most similar vectors to a given vector and return the vector similarities\n",
|
||||
"- ``max_marginal_relevance_search``: Find the most similar vectors to a given vector while also optimizing for diversity"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -596,6 +597,26 @@
|
||||
"print(results[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use maximal marginal relevance search to diversify results\n",
|
||||
"results = rds.max_marginal_relevance_search(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# the lambda_mult parameter controls the diversity of the results, the lower the more diverse\n",
|
||||
"results = rds.max_marginal_relevance_search(\"foo\", lambda_mult=0.1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -17,6 +17,10 @@ def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
|
||||
return np.array(array).astype(dtype).tobytes()
|
||||
|
||||
|
||||
def _buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]:
|
||||
return np.frombuffer(buffer, dtype=dtype).tolist()
|
||||
|
||||
|
||||
class TokenEscaper:
|
||||
"""
|
||||
Escape punctuation within an input string.
|
||||
|
@ -17,8 +17,10 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from langchain._api import deprecated
|
||||
@ -30,6 +32,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utilities.redis import (
|
||||
_array_to_buffer,
|
||||
_buffer_to_array,
|
||||
check_redis_module_exist,
|
||||
get_client,
|
||||
)
|
||||
@ -39,6 +42,7 @@ from langchain.vectorstores.redis.constants import (
|
||||
REDIS_REQUIRED_MODULES,
|
||||
REDIS_TAG_SEPARATOR,
|
||||
)
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -803,8 +807,10 @@ class Redis(VectorStore):
|
||||
+ "score_threshold will be removed in a future release.",
|
||||
)
|
||||
|
||||
query_embedding = self._embeddings.embed_query(query)
|
||||
|
||||
redis_query, params_dict = self._prepare_query(
|
||||
query,
|
||||
query_embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
with_metadata=return_metadata,
|
||||
@ -858,13 +864,48 @@ class Redis(VectorStore):
|
||||
Defaults to None.
|
||||
return_metadata (bool, optional): Whether to return metadata.
|
||||
Defaults to True.
|
||||
distance_threshold (Optional[float], optional): Distance threshold
|
||||
for vector distance from query vector. Defaults to None.
|
||||
distance_threshold (Optional[float], optional): Maximum vector distance
|
||||
between selected documents and the query vector. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents that are most similar to the query
|
||||
text.
|
||||
"""
|
||||
query_embedding = self._embeddings.embed_query(query)
|
||||
return self.similarity_search_by_vector(
|
||||
query_embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
return_metadata=return_metadata,
|
||||
distance_threshold=distance_threshold,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[RedisFilterExpression] = None,
|
||||
return_metadata: bool = True,
|
||||
distance_threshold: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Run similarity search between a query vector and the indexed vectors.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The query vector for which to find similar
|
||||
documents.
|
||||
k (int): The number of documents to return. Default is 4.
|
||||
filter (RedisFilterExpression, optional): Optional metadata filter.
|
||||
Defaults to None.
|
||||
return_metadata (bool, optional): Whether to return metadata.
|
||||
Defaults to True.
|
||||
distance_threshold (Optional[float], optional): Maximum vector distance
|
||||
between selected documents and the query vector. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of documents that are most similar to the query
|
||||
text.
|
||||
"""
|
||||
try:
|
||||
import redis
|
||||
@ -884,7 +925,7 @@ class Redis(VectorStore):
|
||||
)
|
||||
|
||||
redis_query, params_dict = self._prepare_query(
|
||||
query,
|
||||
embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
distance_threshold=distance_threshold,
|
||||
@ -920,6 +961,74 @@ class Redis(VectorStore):
|
||||
)
|
||||
return docs
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[RedisFilterExpression] = None,
    return_metadata: bool = True,
    distance_threshold: Optional[float] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query (str): Text to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult (float): Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter (RedisFilterExpression, optional): Optional metadata filter.
            Defaults to None.
        return_metadata (bool, optional): Whether to return metadata.
            Defaults to True.
        distance_threshold (Optional[float], optional): Maximum vector distance
            between selected documents and the query vector. Defaults to None.

    Returns:
        List[Document]: A list of Documents selected by maximal marginal
            relevance. Candidates whose stored vector can no longer be read
            from Redis are left out, so fewer than ``k`` documents may be
            returned.
    """
    # Embed the query once; the same vector drives both the prefetch and MMR.
    query_embedding = self._embeddings.embed_query(query)

    # Fetch the initial candidate documents.
    prefetch_docs = self.similarity_search_by_vector(
        query_embedding,
        k=fetch_k,
        filter=filter,
        return_metadata=return_metadata,
        distance_threshold=distance_threshold,
        **kwargs,
    )
    if not prefetch_docs:
        return []

    vector_key = self._schema.content_vector_key
    vector_dtype = self._schema.vector_dtype

    # Collect documents and their embeddings side by side so the indices
    # returned by the MMR routine always refer to the right document.
    # NOTE(review): this relies on every prefetched document carrying an "id"
    # metadata entry -- confirm that still holds when return_metadata=False.
    candidates: List[Document] = []
    candidate_embeddings: List[List[float]] = []
    for doc in prefetch_docs:
        raw_vector = cast(
            Optional[bytes], self.client.hget(doc.metadata["id"], vector_key)
        )
        if raw_vector is None:
            # The hash (or its vector field) disappeared between the search
            # and this lookup; decoding None would raise, so skip it.
            continue
        candidates.append(doc)
        candidate_embeddings.append(_buffer_to_array(raw_vector, dtype=vector_dtype))

    if not candidates:
        return []

    # Select documents using maximal marginal relevance.
    selected_indices = maximal_marginal_relevance(
        np.array(query_embedding), candidate_embeddings, lambda_mult=lambda_mult, k=k
    )
    return [candidates[i] for i in selected_indices]
|
||||
|
||||
def _collect_metadata(self, result: "Document") -> Dict[str, Any]:
|
||||
"""Collect metadata from Redis.
|
||||
|
||||
@ -952,19 +1061,16 @@ class Redis(VectorStore):
|
||||
|
||||
def _prepare_query(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[RedisFilterExpression] = None,
|
||||
distance_threshold: Optional[float] = None,
|
||||
with_metadata: bool = True,
|
||||
with_distance: bool = False,
|
||||
) -> Tuple["Query", Dict[str, Any]]:
|
||||
# Creates embedding vector from user query
|
||||
embedding = self._embeddings.embed_query(query)
|
||||
|
||||
# Creates Redis query
|
||||
params_dict: Dict[str, Union[str, bytes, float]] = {
|
||||
"vector": _array_to_buffer(embedding, self._schema.vector_dtype),
|
||||
"vector": _array_to_buffer(query_embedding, self._schema.vector_dtype),
|
||||
}
|
||||
|
||||
# prepare return fields including score
|
||||
|
@ -187,12 +187,21 @@ def test_redis_filters_1(
|
||||
documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||
)
|
||||
|
||||
output = docsearch.similarity_search("foo", k=3, filter=filter_expr)
|
||||
sim_output = docsearch.similarity_search("foo", k=3, filter=filter_expr)
|
||||
mmr_output = docsearch.max_marginal_relevance_search(
|
||||
"foo", k=3, fetch_k=5, filter=filter_expr
|
||||
)
|
||||
|
||||
assert len(output) == expected_length
|
||||
assert len(sim_output) == expected_length
|
||||
assert len(mmr_output) == expected_length
|
||||
|
||||
if expected_nums is not None:
|
||||
for out in output:
|
||||
for out in sim_output:
|
||||
assert (
|
||||
out.metadata["text"] in expected_nums
|
||||
or int(out.metadata["num"]) in expected_nums
|
||||
)
|
||||
for out in mmr_output:
|
||||
assert (
|
||||
out.metadata["text"] in expected_nums
|
||||
or int(out.metadata["num"]) in expected_nums
|
||||
@ -302,7 +311,6 @@ def test_similarity_search_limit_distance(texts: List[str]) -> None:
|
||||
|
||||
def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None:
|
||||
"""Test similarity search with score with limit score."""
|
||||
|
||||
docsearch = Redis.from_texts(
|
||||
texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||
)
|
||||
@ -317,6 +325,32 @@ def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> N
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_max_marginal_relevance_search(texts: List[str]) -> None:
    """Exercise MMR search against similarity search and several settings."""
    store = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
    query = texts[0]

    # When every fetched candidate is kept, MMR is expected to match the
    # plain similarity search result.
    mmr_all = store.max_marginal_relevance_search(query, k=3, fetch_k=3)
    similar = store.similarity_search(query, k=3)
    assert mmr_all == similar

    # Default lambda_mult: the two picks are the first two texts.
    mmr_default = store.max_marginal_relevance_search(query, k=2, fetch_k=3)
    assert len(mmr_default) == 2
    assert [doc.page_content for doc in mmr_default] == [texts[0], texts[1]]

    # A low lambda_mult asks for more diversity, so the second pick changes.
    mmr_diverse = store.max_marginal_relevance_search(
        query, k=2, fetch_k=3, lambda_mult=0.1
    )
    assert len(mmr_diverse) == 2
    assert [doc.page_content for doc in mmr_diverse] == [texts[0], texts[2]]

    # if fetch_k < k, then the output will be less than k
    mmr_capped = store.max_marginal_relevance_search(query, k=3, fetch_k=2)
    assert len(mmr_capped) == 2
    assert drop(store.index_name)
|
||||
|
||||
|
||||
def test_delete(texts: List[str]) -> None:
|
||||
"""Test deleting a new document"""
|
||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||
|
Loading…
Reference in New Issue
Block a user