"""**Retriever** class returns `Document` objects given a text **query**. It is more general than a vector store. A retriever does not need to be able to store documents, only to return (or retrieve) it. Vector stores can be used as the backbone of a retriever, but there are other types of retrievers as well. """ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature from typing import TYPE_CHECKING, Any from pydantic import ConfigDict from typing_extensions import Self, TypedDict, override from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain_core.documents import Document from langchain_core.runnables import ( Runnable, RunnableConfig, RunnableSerializable, ensure_config, ) from langchain_core.runnables.config import run_in_executor if TYPE_CHECKING: from langchain_core.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) RetrieverInput = str RetrieverOutput = list[Document] RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] RetrieverOutputLike = Runnable[Any, RetrieverOutput] class LangSmithRetrieverParams(TypedDict, total=False): """LangSmith parameters for tracing.""" ls_retriever_name: str """Retriever name.""" ls_vector_store_provider: str | None """Vector store provider.""" ls_embedding_provider: str | None """Embedding provider.""" ls_embedding_model: str | None """Embedding model.""" class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): """Abstract base class for a document retrieval system. A retrieval system is defined as something that can take string queries and return the most 'relevant' documents from some source. Usage: A retriever follows the standard `Runnable` interface, and should be used via the standard `Runnable` methods of `invoke`, `ainvoke`, `batch`, `abatch`. Implementation: When implementing a custom retriever, the class should implement the `_get_relevant_documents` method to define the logic for retrieving documents. Optionally, an async native implementations can be provided by overriding the `_aget_relevant_documents` method. !!! example "Retriever that returns the first 5 documents from a list of documents" ```python from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever class SimpleRetriever(BaseRetriever): docs: list[Document] k: int = 5 def _get_relevant_documents(self, query: str) -> list[Document]: \"\"\"Return the first k documents from the list of documents\"\"\" return self.docs[:self.k] async def _aget_relevant_documents(self, query: str) -> list[Document]: \"\"\"(Optional) async native implementation.\"\"\" return self.docs[:self.k] ``` !!! example "Simple retriever based on a scikit-learn vectorizer" ```python from sklearn.metrics.pairwise import cosine_similarity class TFIDFRetriever(BaseRetriever, BaseModel): vectorizer: Any docs: list[Document] tfidf_array: Any k: int = 4 class Config: arbitrary_types_allowed = True def _get_relevant_documents(self, query: str) -> list[Document]: # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) query_vec = self.vectorizer.transform([query]) # Op -- (n_docs,1) -- Cosine Sim with each doc results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] ``` """ model_config = ConfigDict( arbitrary_types_allowed=True, ) _new_arg_supported: bool = False _expects_other_args: bool = False tags: list[str] | None = None """Optional list of tags associated with the retriever. These tags will be associated with each call to this retriever, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a retriever with its use case. """ metadata: dict[str, Any] | None = None """Optional metadata associated with the retriever. This metadata will be associated with each call to this retriever, and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a retriever with its use case. """ @override def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) parameters = signature(cls._get_relevant_documents).parameters cls._new_arg_supported = parameters.get("run_manager") is not None if ( not cls._new_arg_supported and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents ): # we need to tolerate no run_manager in _aget_relevant_documents signature async def _aget_relevant_documents( self: Self, query: str ) -> list[Document]: return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore[call-arg] cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment] # If a V1 retriever broke the interface and expects additional arguments cls._expects_other_args = ( len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 ) def _get_ls_params(self, **_kwargs: Any) -> LangSmithRetrieverParams: """Get standard params for tracing.""" default_retriever_name = self.get_name() if default_retriever_name.startswith("Retriever"): default_retriever_name = default_retriever_name[9:] elif default_retriever_name.endswith("Retriever"): default_retriever_name = default_retriever_name[:-9] default_retriever_name = default_retriever_name.lower() return LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) @override def invoke( self, input: str, config: RunnableConfig | None = None, **kwargs: Any ) -> list[Document]: """Invoke the retriever to get relevant documents. Main entry point for synchronous retriever invocations. Args: input: The query string. config: Configuration for the retriever. **kwargs: Additional arguments to pass to the retriever. Returns: List of relevant documents. Examples: ```python retriever.invoke("query") ``` """ config = ensure_config(config) inheritable_metadata = { **(config.get("metadata") or {}), **self._get_ls_params(**kwargs), } callback_manager = CallbackManager.configure( config.get("callbacks"), None, verbose=kwargs.get("verbose", False), inheritable_tags=config.get("tags"), local_tags=self.tags, inheritable_metadata=inheritable_metadata, local_metadata=self.metadata, ) run_manager = callback_manager.on_retriever_start( None, input, name=config.get("run_name") or self.get_name(), run_id=kwargs.pop("run_id", None), ) try: kwargs_ = kwargs if self._expects_other_args else {} if self._new_arg_supported: result = self._get_relevant_documents( input, run_manager=run_manager, **kwargs_ ) else: result = self._get_relevant_documents(input, **kwargs_) except Exception as e: run_manager.on_retriever_error(e) raise else: run_manager.on_retriever_end( result, ) return result @override async def ainvoke( self, input: str, config: RunnableConfig | None = None, **kwargs: Any, ) -> list[Document]: """Asynchronously invoke the retriever to get relevant documents. Main entry point for asynchronous retriever invocations. Args: input: The query string. config: Configuration for the retriever. **kwargs: Additional arguments to pass to the retriever. Returns: List of relevant documents. Examples: ```python await retriever.ainvoke("query") ``` """ config = ensure_config(config) inheritable_metadata = { **(config.get("metadata") or {}), **self._get_ls_params(**kwargs), } callback_manager = AsyncCallbackManager.configure( config.get("callbacks"), None, verbose=kwargs.get("verbose", False), inheritable_tags=config.get("tags"), local_tags=self.tags, inheritable_metadata=inheritable_metadata, local_metadata=self.metadata, ) run_manager = await callback_manager.on_retriever_start( None, input, name=config.get("run_name") or self.get_name(), run_id=kwargs.pop("run_id", None), ) try: kwargs_ = kwargs if self._expects_other_args else {} if self._new_arg_supported: result = await self._aget_relevant_documents( input, run_manager=run_manager, **kwargs_ ) else: result = await self._aget_relevant_documents(input, **kwargs_) except Exception as e: await run_manager.on_retriever_error(e) raise else: await run_manager.on_retriever_end( result, ) return result @abstractmethod def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: """Get documents relevant to a query. Args: query: String to find relevant documents for. run_manager: The callback handler to use. Returns: List of relevant documents. """ async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> list[Document]: """Asynchronously get documents relevant to a query. Args: query: String to find relevant documents for run_manager: The callback handler to use Returns: List of relevant documents """ return await run_in_executor( None, self._get_relevant_documents, query, run_manager=run_manager.get_sync(), )