mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
core, tests: more tolerant _aget_relevant_documents function (#28462)
This commit is contained in:
parent
bc636ccc60
commit
18386c16c7
@ -27,7 +27,7 @@ from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict
|
||||
from typing_extensions import Self, TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.documents import Document
|
||||
@ -180,6 +180,18 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
||||
parameters = signature(cls._get_relevant_documents).parameters
|
||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||
if (
|
||||
not cls._new_arg_supported
|
||||
and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents
|
||||
):
|
||||
# we need to tolerate no run_manager in _aget_relevant_documents signature
|
||||
async def _aget_relevant_documents(
|
||||
self: Self, query: str
|
||||
) -> list[Document]:
|
||||
return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore
|
||||
|
||||
cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment]
|
||||
|
||||
# If a V1 retriever broke the interface and expects additional arguments
|
||||
cls._expects_other_args = (
|
||||
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
||||
|
0
libs/core/tests/unit_tests/test_retrievers.py
Normal file
0
libs/core/tests/unit_tests/test_retrievers.py
Normal file
@ -1,6 +1,5 @@
|
||||
from typing import Any, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
@ -11,9 +10,7 @@ class ParrotRetriever(BaseRetriever):
|
||||
parrot_name: str
|
||||
k: int = 3
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> list[Document]:
|
||||
def _get_relevant_documents(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
k = kwargs.get("k", self.k)
|
||||
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user