mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 18:24:10 +00:00
core, tests: more tolerant _aget_relevant_documents function (#28462)
This commit is contained in:
parent
bc636ccc60
commit
18386c16c7
@ -27,7 +27,7 @@ from inspect import signature
|
|||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import Self, TypedDict
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -180,6 +180,18 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
|||||||
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
||||||
parameters = signature(cls._get_relevant_documents).parameters
|
parameters = signature(cls._get_relevant_documents).parameters
|
||||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||||
|
if (
|
||||||
|
not cls._new_arg_supported
|
||||||
|
and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents
|
||||||
|
):
|
||||||
|
# we need to tolerate no run_manager in _aget_relevant_documents signature
|
||||||
|
async def _aget_relevant_documents(
|
||||||
|
self: Self, query: str
|
||||||
|
) -> list[Document]:
|
||||||
|
return await run_in_executor(None, self._get_relevant_documents, query) # type: ignore
|
||||||
|
|
||||||
|
cls._aget_relevant_documents = _aget_relevant_documents # type: ignore[assignment]
|
||||||
|
|
||||||
# If a V1 retriever broke the interface and expects additional arguments
|
# If a V1 retriever broke the interface and expects additional arguments
|
||||||
cls._expects_other_args = (
|
cls._expects_other_args = (
|
||||||
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
||||||
|
0
libs/core/tests/unit_tests/test_retrievers.py
Normal file
0
libs/core/tests/unit_tests/test_retrievers.py
Normal file
@ -1,6 +1,5 @@
|
|||||||
from typing import Any, Type
|
from typing import Any, Type
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
@ -11,9 +10,7 @@ class ParrotRetriever(BaseRetriever):
|
|||||||
parrot_name: str
|
parrot_name: str
|
||||||
k: int = 3
|
k: int = 3
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(self, query: str, **kwargs: Any) -> list[Document]:
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
|
||||||
) -> list[Document]:
|
|
||||||
k = kwargs.get("k", self.k)
|
k = kwargs.get("k", self.k)
|
||||||
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
|
return [Document(page_content=f"{self.parrot_name} says: {query}")] * k
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user