mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 02:13:23 +00:00
Add mmr support to redis retriever (#10556)
This commit is contained in:
parent
ccf71e23e8
commit
7f3f6097e7
@ -158,7 +158,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -178,7 +178,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -242,7 +242,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 7,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
@ -253,7 +253,7 @@
|
|||||||
"rds = Redis.from_texts(\n",
|
"rds = Redis.from_texts(\n",
|
||||||
" texts,\n",
|
" texts,\n",
|
||||||
" embeddings,\n",
|
" embeddings,\n",
|
||||||
" metadatas=metadats,\n",
|
" metadatas=metadata,\n",
|
||||||
" redis_url=\"redis://localhost:6379\",\n",
|
" redis_url=\"redis://localhost:6379\",\n",
|
||||||
" index_name=\"users\"\n",
|
" index_name=\"users\"\n",
|
||||||
")"
|
")"
|
||||||
@ -597,7 +597,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -607,7 +607,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 11,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -1110,6 +1110,38 @@
|
|||||||
"retriever.get_relevant_documents(\"foo\")"
|
"retriever.get_relevant_documents(\"foo\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = rds.as_retriever(search_type=\"mmr\", search_kwargs={\"fetch_k\": 20, \"k\": 4, \"lambda_mult\": 0.1})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[Document(page_content='foo', metadata={'id': 'doc:users:8f6b673b390647809d510112cde01a27', 'user': 'john', 'job': 'engineer', 'credit_score': 'high', 'age': '18'}),\n",
|
||||||
|
" Document(page_content='bar', metadata={'id': 'doc:users:93521560735d42328b48c9c6f6418d6a', 'user': 'tyler', 'job': 'engineer', 'credit_score': 'high', 'age': '100'}),\n",
|
||||||
|
" Document(page_content='foo', metadata={'id': 'doc:users:125ecd39d07845eabf1a699d44134a5b', 'user': 'nancy', 'job': 'doctor', 'credit_score': 'high', 'age': '94'}),\n",
|
||||||
|
" Document(page_content='foo', metadata={'id': 'doc:users:d6200ab3764c466082fde3eaab972a2a', 'user': 'derrick', 'job': 'doctor', 'credit_score': 'low', 'age': '45'})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"retriever.get_relevant_documents(\"foo\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -1227,7 +1259,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.3"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1425,6 +1425,7 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
"similarity",
|
"similarity",
|
||||||
"similarity_distance_threshold",
|
"similarity_distance_threshold",
|
||||||
"similarity_score_threshold",
|
"similarity_score_threshold",
|
||||||
|
"mmr",
|
||||||
]
|
]
|
||||||
"""Allowed search types."""
|
"""Allowed search types."""
|
||||||
|
|
||||||
@ -1438,7 +1439,6 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||||
|
|
||||||
elif self.search_type == "similarity_distance_threshold":
|
elif self.search_type == "similarity_distance_threshold":
|
||||||
if self.search_kwargs["distance_threshold"] is None:
|
if self.search_kwargs["distance_threshold"] is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1454,6 +1454,10 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
docs = [doc for doc, _ in docs_and_similarities]
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
|
elif self.search_type == "mmr":
|
||||||
|
docs = self.vectorstore.max_marginal_relevance_search(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
||||||
return docs
|
return docs
|
||||||
|
Loading…
Reference in New Issue
Block a user