mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
RFC: rich retrieval results
This commit is contained in:
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, overload
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import Serializable
|
||||
from langchain_core.load.dump import dumpd
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
@@ -22,8 +24,19 @@ if TYPE_CHECKING:
|
||||
Callbacks,
|
||||
)
|
||||
|
||||
|
||||
class DocumentResult(Serializable):
    """A single retrieved document together with metadata about that result."""

    # The retrieved document itself.
    document: Document
    # Free-form metadata describing this individual result (for example a
    # score attached by the retriever). Defaults to a fresh empty dict.
    metadata: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RetrievalResult(Serializable):
    """The outcome of one retrieval: per-document results plus overall metadata."""

    # One DocumentResult per retrieved document.
    documents: List[DocumentResult]
    # Free-form metadata describing the retrieval as a whole. Defaults to a
    # fresh empty dict.
    metadata: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
# A retriever is queried with a plain string.
RetrieverInput = str
# A retriever returns either a bare list of documents or a rich
# RetrievalResult that carries metadata alongside each document.
RetrieverOutput = Union[List[Document], RetrievalResult]
# Runnable aliases parameterised with the retriever input/output types.
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
||||
|
||||
@@ -39,14 +52,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
|
||||
class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
vectorizer: Any
|
||||
docs: List[Document]
|
||||
docs: RetrieverOutput
|
||||
tfidf_array: Any
|
||||
k: int = 4
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def get_relevant_documents(self, query: str) -> RetrieverOutput:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
|
||||
@@ -116,7 +129,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None
|
||||
) -> List[Document]:
|
||||
) -> RetrieverOutput:
|
||||
config = ensure_config(config)
|
||||
return self.get_relevant_documents(
|
||||
input,
|
||||
@@ -131,7 +144,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
input: str,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Document]:
|
||||
) -> RetrieverOutput:
|
||||
config = ensure_config(config)
|
||||
return await self.aget_relevant_documents(
|
||||
input,
|
||||
@@ -153,9 +166,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
def _get_relevant_results(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> RetrievalResult:
    """Get results relevant to a query.

    This implementation is a placeholder that always raises; a retriever
    that can produce rich results has to provide its own version.

    Args:
        query: String to find relevant documents for
        run_manager: The callbacks handler to use
    Returns:
        RetrievalResult
    Raises:
        NotImplementedError: Always, in this implementation.
    """
    # Name the concrete class so the failure points at the retriever that
    # lacks rich-result support.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement _get_relevant_results"
    )
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
) -> RetrieverOutput:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for
|
||||
@@ -170,16 +195,45 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
run_manager=run_manager.get_sync(),
|
||||
)
|
||||
|
||||
@overload
def get_relevant_documents(
    self,
    query: str,
    *,
    docs_only: Literal[True] = True,
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    run_name: Optional[str] = None,
    **kwargs: Any,
) -> List[Document]:
    """Typing overload: ``docs_only`` omitted or ``True`` yields ``List[Document]``.

    This stub only informs type checkers; it carries no runtime behaviour.
    """
    ...
|
||||
|
||||
@overload
def get_relevant_documents(
    self,
    query: str,
    *,
    docs_only: Literal[False],
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    run_name: Optional[str] = None,
    **kwargs: Any,
) -> RetrievalResult:
    """Typing overload: an explicit ``docs_only=False`` yields ``RetrievalResult``.

    This stub only informs type checkers; it carries no runtime behaviour.
    """
    ...
|
||||
|
||||
def get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
docs_only: bool = True,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> RetrieverOutput:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
@@ -211,19 +265,27 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
_kwargs = kwargs if self._expects_other_args else {}
|
||||
if self._new_arg_supported:
|
||||
result = self._get_relevant_documents(
|
||||
query, run_manager=run_manager, **_kwargs
|
||||
)
|
||||
if docs_only:
|
||||
_kwargs = kwargs if self._expects_other_args else {}
|
||||
if self._new_arg_supported:
|
||||
result: RetrieverOutput = self._get_relevant_documents(
|
||||
query, run_manager=run_manager, **_kwargs
|
||||
)
|
||||
else:
|
||||
result = self._get_relevant_documents(query, **_kwargs)
|
||||
else:
|
||||
result = self._get_relevant_documents(query, **_kwargs)
|
||||
result = self._get_relevant_results(query, **_kwargs)
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
callback_result = (
|
||||
result
|
||||
if isinstance(result, list)
|
||||
else [doc_res.document for doc_res in result.documents]
|
||||
)
|
||||
run_manager.on_retriever_end(
|
||||
result,
|
||||
callback_result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
@@ -237,7 +299,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
run_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
) -> RetrieverOutput:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import (
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.retrievers import BaseRetriever, DocumentResult, RetrievalResult
|
||||
from langchain_core.runnables.config import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -650,22 +650,52 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
def _get_relevant_documents(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
    """Get the documents relevant to a query, without result metadata.

    Delegates to ``_get_relevant_results`` and unwraps each entry of the
    returned ``documents`` to its bare ``document``.

    Args:
        query: String to find relevant documents for
        run_manager: The callbacks handler to use; passed through unchanged
    Returns:
        List of relevant documents, in the order the results were returned
    """
    rich_result = self._get_relevant_results(query, run_manager=run_manager)
    return [doc_result.document for doc_result in rich_result.documents]
|
||||
|
||||
def _get_relevant_results(
    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> RetrievalResult:
    """Search the vector store and return documents with per-result metadata.

    The search performed depends on ``self.search_type``:

    * ``"similarity"`` calls ``similarity_search_with_score`` and stores each
      score under the ``"score"`` metadata key.
    * ``"similarity_score_threshold"`` calls
      ``similarity_search_with_relevance_scores`` and stores each score under
      the ``"relevance_score"`` metadata key.
    * ``"mmr"`` calls ``max_marginal_relevance_search``; the documents carry
      no per-result metadata.

    The result-level metadata records the query, the search type and the
    search kwargs that were used.

    Args:
        query: String to find relevant documents for
        run_manager: The callbacks handler to use (not used by this method)
    Returns:
        RetrievalResult
    Raises:
        ValueError: If ``self.search_type`` is not one of the types above.
    """
    result_metadata = {
        "query": query,
        "search_type": self.search_type,
        "search_kwargs": self.search_kwargs,
    }
    if self.search_type == "similarity":
        docs_and_scores = self.vectorstore.similarity_search_with_score(
            query, **self.search_kwargs
        )
        documents = [
            DocumentResult(document=doc, metadata={"score": score})
            for doc, score in docs_and_scores
        ]
    elif self.search_type == "similarity_score_threshold":
        docs_and_similarities = (
            self.vectorstore.similarity_search_with_relevance_scores(
                query, **self.search_kwargs
            )
        )
        documents = [
            DocumentResult(document=doc, metadata={"relevance_score": score})
            for doc, score in docs_and_similarities
        ]
    elif self.search_type == "mmr":
        docs = self.vectorstore.max_marginal_relevance_search(
            query, **self.search_kwargs
        )
        documents = [DocumentResult(document=doc) for doc in docs]
    else:
        raise ValueError(f"search_type of {self.search_type} not allowed.")
    # Every branch shares the same result-level metadata, so build the
    # RetrievalResult once and return it (never the bare document list).
    return RetrievalResult(documents=documents, metadata=result_metadata)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
|
||||
Reference in New Issue
Block a user