core, tests: more tolerant _aget_relevant_documents function (#28462)

This commit is contained in:
Erick Friis 2024-12-05 16:49:30 -08:00 committed by GitHub
parent bc636ccc60
commit 18386c16c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 5 deletions

View File

@ -27,7 +27,7 @@ from inspect import signature
from typing import TYPE_CHECKING, Any, Optional
from pydantic import ConfigDict
from typing_extensions import TypedDict
from typing_extensions import Self, TypedDict
from langchain_core._api import deprecated
from langchain_core.documents import Document
@ -180,6 +180,18 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
if (
not cls._new_arg_supported
and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents
):
# we need to tolerate no run_manager in _aget_relevant_documents signature
async def _aget_relevant_documents(
self: Self, query: str
) -> list[Document]:
return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore
cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment]
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0

View File

@ -1,6 +1,5 @@
from typing import Any, Type
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
@ -11,9 +10,7 @@ class ParrotRetriever(BaseRetriever):
parrot_name: str
k: int = 3
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> list[Document]:
def _get_relevant_documents(self, query: str, **kwargs: Any) -> list[Document]:
k = kwargs.get("k", self.k)
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k