From 617a477fbba111731ba9805937d9a801fb24e5dc Mon Sep 17 00:00:00 2001 From: Bagatur Date: Mon, 22 Jan 2024 12:46:33 -0800 Subject: [PATCH] RFC: rich retrieval results --- libs/core/langchain_core/retrievers.py | 92 ++++++++++++++++++++---- libs/core/langchain_core/vectorstores.py | 38 ++++++++-- 2 files changed, 111 insertions(+), 19 deletions(-) diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index f2156366959..ff9fcc35da6 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -3,10 +3,12 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, overload from langchain_core.documents import Document +from langchain_core.load import Serializable from langchain_core.load.dump import dumpd +from langchain_core.pydantic_v1 import Field from langchain_core.runnables import ( Runnable, RunnableConfig, @@ -22,8 +24,19 @@ if TYPE_CHECKING: Callbacks, ) + +class DocumentResult(Serializable): + document: Document + metadata: dict = Field(default_factory=dict) + + +class RetrievalResult(Serializable): + documents: List[DocumentResult] + metadata: dict = Field(default_factory=dict) + + RetrieverInput = str -RetrieverOutput = List[Document] +RetrieverOutput = Union[List[Document], RetrievalResult] RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] RetrieverOutputLike = Runnable[Any, RetrieverOutput] @@ -39,14 +52,14 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): class TFIDFRetriever(BaseRetriever, BaseModel): vectorizer: Any - docs: List[Document] + docs: RetrieverOutput tfidf_array: Any k: int = 4 class Config: arbitrary_types_allowed = True - def get_relevant_documents(self, query: str) -> List[Document]: + def get_relevant_documents(self, query: str) -> RetrieverOutput: from sklearn.metrics.pairwise import cosine_similarity # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) @@ -116,7 +129,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): def invoke( self, input: str, config: Optional[RunnableConfig] = None - ) -> List[Document]: + ) -> RetrieverOutput: config = ensure_config(config) return self.get_relevant_documents( input, @@ -131,7 +144,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): input: str, config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], - ) -> List[Document]: + ) -> RetrieverOutput: config = ensure_config(config) return await self.aget_relevant_documents( input, @@ -153,9 +166,21 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): List of relevant documents """ + def _get_relevant_results( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> RetrievalResult: + """Get results relevant to a query. + Args: + query: String to find relevant documents for + run_manager: The callbacks handler to use + Returns: + RetrievalResult + """ + raise NotImplementedError() + async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun - ) -> List[Document]: + ) -> RetrieverOutput: """Asynchronously get documents relevant to a query. Args: query: String to find relevant documents for @@ -170,16 +195,45 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): run_manager=run_manager.get_sync(), ) + @overload def get_relevant_documents( self, query: str, *, + docs_only: Literal[True] = True, callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, **kwargs: Any, ) -> List[Document]: + """""" + + @overload + def get_relevant_documents( + self, + query: str, + *, + docs_only: Literal[False], + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> RetrievalResult: + """""" + + def get_relevant_documents( + self, + query: str, + *, + docs_only: bool = True, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, + **kwargs: Any, + ) -> RetrieverOutput: """Retrieve documents relevant to a query. Args: query: string to find relevant documents for @@ -211,19 +265,27 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): **kwargs, ) try: - _kwargs = kwargs if self._expects_other_args else {} - if self._new_arg_supported: - result = self._get_relevant_documents( - query, run_manager=run_manager, **_kwargs - ) + if docs_only: + _kwargs = kwargs if self._expects_other_args else {} + if self._new_arg_supported: + result: RetrieverOutput = self._get_relevant_documents( + query, run_manager=run_manager, **_kwargs + ) + else: + result = self._get_relevant_documents(query, **_kwargs) else: - result = self._get_relevant_documents(query, **_kwargs) + result = self._get_relevant_results(query, **_kwargs) except Exception as e: run_manager.on_retriever_error(e) raise e else: + callback_result = ( + result + if isinstance(result, list) + else [doc_res.document for doc_res in result.documents] + ) run_manager.on_retriever_end( - result, + callback_result, **kwargs, ) return result @@ -237,7 +299,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, **kwargs: Any, - ) -> List[Document]: + ) -> RetrieverOutput: """Asynchronously get documents relevant to a query. Args: query: string to find relevant documents for diff --git a/libs/core/langchain_core/vectorstores.py b/libs/core/langchain_core/vectorstores.py index 2fb32f86b1a..d4dc7a506e7 100644 --- a/libs/core/langchain_core/vectorstores.py +++ b/libs/core/langchain_core/vectorstores.py @@ -21,7 +21,7 @@ from typing import ( from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.retrievers import BaseRetriever +from langchain_core.retrievers import BaseRetriever, DocumentResult, RetrievalResult from langchain_core.runnables.config import run_in_executor if TYPE_CHECKING: @@ -650,22 +650,52 @@ class VectorStoreRetriever(BaseRetriever): def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: + result = self._get_relevant_results(query, run_manager=run_manager) + return [doc_result.document for doc_result in result.documents] + + def _get_relevant_results( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> RetrievalResult: + result_metadata = { + "query": query, + "search_type": self.search_type, + "search_kwargs": self.search_kwargs, + } if self.search_type == "similarity": - docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + docs_and_scores = self.vectorstore.similarity_search_with_score( + query, **self.search_kwargs + ) + result = RetrievalResult( + documents=[ + DocumentResult(document=doc, metadata={"score": score}) + for doc, score in docs_and_scores + ], + metadata=result_metadata, + ) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( self.vectorstore.similarity_search_with_relevance_scores( query, **self.search_kwargs ) ) - docs = [doc for doc, _ in docs_and_similarities] + result = RetrievalResult( + documents=[ + DocumentResult(document=doc, metadata={"relevance_score": score}) + for doc, score in docs_and_similarities + ], + metadata=result_metadata, + ) elif self.search_type == "mmr": docs = self.vectorstore.max_marginal_relevance_search( query, **self.search_kwargs ) + result = RetrievalResult( + documents=[DocumentResult(document=doc) for doc in docs], + metadata=result_metadata, + ) else: raise ValueError(f"search_type of {self.search_type} not allowed.") - return docs + return result async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun