Compare commits

...

1 Commits

Author SHA1 Message Date
Eugene Yurtsev
54acb00d30 x 2024-04-19 17:50:49 -04:00

View File

@@ -33,7 +33,11 @@ from langchain_core.runnables import (
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
from langchain_core.runnables.config import (
get_callback_manager_for_config,
patch_config,
run_in_executor,
)
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -115,7 +119,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
arbitrary_types_allowed = True
_new_arg_supported: bool = False
_accepts_run_manager: bool = False
_expects_other_args: bool = False
tags: Optional[List[str]] = None
"""Optional list of tags associated with the retriever. Defaults to None
@@ -162,10 +166,11 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
cls._accepts_run_manager = parameters.get("run_manager") is not None
cls._accepts_config = parameters.get("config") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
len(set(parameters.keys()) - {"self", "query", "run_manager", "config"}) > 0
)
def invoke(
@@ -190,14 +195,23 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
retriever.invoke("query")
"""
config = ensure_config(config)
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
if self._accepts_config:
return self._call_with_config_2(
input,
config,
**kwargs,
)
else:
# Then using old code path
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
async def ainvoke(
self,
@@ -224,6 +238,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
await retriever.ainvoke("query")
"""
config = ensure_config(config)
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
@@ -235,18 +250,28 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
config: Optional[RunnableConfig] = None,
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
config: Optional[RunnableConfig] = None,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
@@ -259,9 +284,59 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
None,
self._get_relevant_documents,
query,
run_manager=run_manager.get_sync(),
run_manager=run_manager.get_sync() if run_manager else None,
config=config,
)
def _call_with_config_2(
self,
query: str,
config: RunnableConfig,
**kwargs,
):
from langchain_core.callbacks.manager import CallbackManager
config = ensure_config(config)
callbacks = config.get("callbacks")
callback_manager = CallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags"),
local_tags=self.tags,
inheritable_metadata=config.get("tags"),
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
name=config.get("run_name"),
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._accepts_run_manager:
result = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
elif self._accepts_config:
result = self._get_relevant_documents(
query,
config=patch_config(config, callbacks=run_manager.get_child()),
**_kwargs,
)
else:
result = self._get_relevant_documents(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
)
return result
def get_relevant_documents(
self,
query: str,
@@ -291,39 +366,16 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Returns:
List of relevant documents
"""
from langchain_core.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
return self._call_with_config_2(
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
RunnableConfig(
callbacks=callbacks,
tags=tags,
metadata=metadata,
run_name=run_name,
),
**kwargs,
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = self._get_relevant_documents(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
)
return result
async def aget_relevant_documents(
self,
@@ -373,7 +425,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
if self._accepts_run_manager:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **_kwargs
)