mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 02:29:17 +00:00
Harrison/new search (#4359)
Co-authored-by: Jiaping(JP) Zhang <vincentzhangv@gmail.com>
This commit is contained in:
parent
545ae8b756
commit
3ce29cb4a6
@ -25,18 +25,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 2,
|
||||
"id": "9fbcc58f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Exiting: Cleaning up .chroma directory\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import FAISS\n",
|
||||
@ -74,6 +66,7 @@
|
||||
"id": "79b783de",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Maximum Marginal Relevance Retrieval\n",
|
||||
"By default, the vectorstore retriever uses similarity search. If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type."
|
||||
]
|
||||
},
|
||||
@ -97,11 +90,42 @@
|
||||
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d958271",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Similarity Score Threshold Retrieval\n",
|
||||
"\n",
|
||||
"You can also a retrieval method that sets a similarity score threshold and only returns documents with a score above that threshold"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d4272ad8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = db.as_retriever(search_type=\"similarity_score_threshold\", search_kwargs={\"score_threshold\": .5})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "438e761d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = retriever.get_relevant_documents(\"what did he say abotu ketanji brown jackson\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c23b7698",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Specifying top k\n",
|
||||
"You can also specify search kwargs like `k` to use when doing retrieval."
|
||||
]
|
||||
},
|
||||
@ -171,7 +195,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||
@ -116,6 +117,16 @@ class VectorStore(ABC):
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
docs_and_similarities = self._similarity_search_with_relevance_scores(
|
||||
query, k=k, **kwargs
|
||||
@ -124,10 +135,23 @@ class VectorStore(ABC):
|
||||
similarity < 0.0 or similarity > 1.0
|
||||
for _, similarity in docs_and_similarities
|
||||
):
|
||||
raise ValueError(
|
||||
warnings.warn(
|
||||
"Relevance scores must be between"
|
||||
f" 0 and 1, got {docs_and_similarities}"
|
||||
)
|
||||
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
if score_threshold is not None:
|
||||
docs_and_similarities = [
|
||||
(doc, similarity)
|
||||
for doc, similarity in docs_and_similarities
|
||||
if similarity >= score_threshold
|
||||
]
|
||||
if len(docs_and_similarities) == 0:
|
||||
warnings.warn(
|
||||
f"No relevant docs were retrieved using the relevance score\
|
||||
threshold {score_threshold}"
|
||||
)
|
||||
return docs_and_similarities
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
@ -324,13 +348,29 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
|
||||
"""Validate search type."""
|
||||
if "search_type" in values:
|
||||
search_type = values["search_type"]
|
||||
if search_type not in ("similarity", "mmr"):
|
||||
if search_type not in ("similarity", "similarity_score_threshold", "mmr"):
|
||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
||||
if search_type == "similarity_score_threshold":
|
||||
score_threshold = values["search_kwargs"].get("score_threshold")
|
||||
if (score_threshold is None) or (
|
||||
not isinstance(score_threshold, float)
|
||||
):
|
||||
raise ValueError(
|
||||
"`score_threshold` is not specified with a float value(0~1) "
|
||||
"in `search_kwargs`."
|
||||
)
|
||||
return values
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
docs_and_similarities = (
|
||||
self.vectorstore.similarity_search_with_relevance_scores(
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
)
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = self.vectorstore.max_marginal_relevance_search(
|
||||
query, **self.search_kwargs
|
||||
|
Loading…
Reference in New Issue
Block a user