mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-04 22:23:50 +00:00
Harrison/new search (#4359)
Co-authored-by: Jiaping(JP) Zhang <vincentzhangv@gmail.com>
This commit is contained in:
parent
545ae8b756
commit
3ce29cb4a6
@ -25,18 +25,10 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 2,
|
||||||
"id": "9fbcc58f",
|
"id": "9fbcc58f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Exiting: Cleaning up .chroma directory\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||||
"from langchain.vectorstores import FAISS\n",
|
"from langchain.vectorstores import FAISS\n",
|
||||||
@ -74,6 +66,7 @@
|
|||||||
"id": "79b783de",
|
"id": "79b783de",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
|
"## Maximum Marginal Relevance Retrieval\n",
|
||||||
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type."
|
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -97,11 +90,42 @@
|
|||||||
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2d958271",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Similarity Score Threshold Retrieval\n",
|
||||||
|
"\n",
|
||||||
|
"You can also use a retrieval method that sets a similarity score threshold and only returns documents with a score above that threshold"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "d4272ad8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"retriever = db.as_retriever(search_type=\"similarity_score_threshold\", search_kwargs={\"score_threshold\": .5})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "438e761d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "c23b7698",
|
"id": "c23b7698",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
|
"## Specifying top k\n",
|
||||||
"You can also specify search kwargs like `k` to use when doing retrieval."
|
"You can also specify search kwargs like `k` to use when doing retrieval."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -171,7 +195,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.3"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||||
@ -116,6 +117,16 @@ class VectorStore(ABC):
|
|||||||
"""Return docs and relevance scores in the range [0, 1].
|
"""Return docs and relevance scores in the range [0, 1].
|
||||||
|
|
||||||
0 is dissimilar, 1 is most similar.
|
0 is dissimilar, 1 is most similar.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: input text
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||||
|
score_threshold: Optional, a floating point value between 0 and 1 to
|
||||||
|
filter the resulting set of retrieved docs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Tuples of (doc, similarity_score)
|
||||||
"""
|
"""
|
||||||
docs_and_similarities = self._similarity_search_with_relevance_scores(
|
docs_and_similarities = self._similarity_search_with_relevance_scores(
|
||||||
query, k=k, **kwargs
|
query, k=k, **kwargs
|
||||||
@ -124,10 +135,23 @@ class VectorStore(ABC):
|
|||||||
similarity < 0.0 or similarity > 1.0
|
similarity < 0.0 or similarity > 1.0
|
||||||
for _, similarity in docs_and_similarities
|
for _, similarity in docs_and_similarities
|
||||||
):
|
):
|
||||||
raise ValueError(
|
warnings.warn(
|
||||||
"Relevance scores must be between"
|
"Relevance scores must be between"
|
||||||
f" 0 and 1, got {docs_and_similarities}"
|
f" 0 and 1, got {docs_and_similarities}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
score_threshold = kwargs.get("score_threshold")
|
||||||
|
if score_threshold is not None:
|
||||||
|
docs_and_similarities = [
|
||||||
|
(doc, similarity)
|
||||||
|
for doc, similarity in docs_and_similarities
|
||||||
|
if similarity >= score_threshold
|
||||||
|
]
|
||||||
|
if len(docs_and_similarities) == 0:
|
||||||
|
warnings.warn(
|
||||||
|
f"No relevant docs were retrieved using the relevance score\
|
||||||
|
threshold {score_threshold}"
|
||||||
|
)
|
||||||
return docs_and_similarities
|
return docs_and_similarities
|
||||||
|
|
||||||
def _similarity_search_with_relevance_scores(
|
def _similarity_search_with_relevance_scores(
|
||||||
@ -324,13 +348,29 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
|||||||
"""Validate search type."""
|
"""Validate search type."""
|
||||||
if "search_type" in values:
|
if "search_type" in values:
|
||||||
search_type = values["search_type"]
|
search_type = values["search_type"]
|
||||||
if search_type not in ("similarity", "mmr"):
|
if search_type not in ("similarity", "similarity_score_threshold", "mmr"):
|
||||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||||
|
if search_type == "similarity_score_threshold":
|
||||||
|
score_threshold = values["search_kwargs"].get("score_threshold")
|
||||||
|
if (score_threshold is None) or (
|
||||||
|
not isinstance(score_threshold, float)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`score_threshold` is not specified with a float value(0~1) "
|
||||||
|
"in `search_kwargs`."
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
if self.search_type == "similarity":
|
if self.search_type == "similarity":
|
||||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||||
|
elif self.search_type == "similarity_score_threshold":
|
||||||
|
docs_and_similarities = (
|
||||||
|
self.vectorstore.similarity_search_with_relevance_scores(
|
||||||
|
query, **self.search_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
docs = [doc for doc, _ in docs_and_similarities]
|
||||||
elif self.search_type == "mmr":
|
elif self.search_type == "mmr":
|
||||||
docs = self.vectorstore.max_marginal_relevance_search(
|
docs = self.vectorstore.max_marginal_relevance_search(
|
||||||
query, **self.search_kwargs
|
query, **self.search_kwargs
|
||||||
|
Loading…
Reference in New Issue
Block a user