From 7fde2791dc7464d43c293e8e8fc403c8a115f8c6 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 30 Sep 2024 16:45:29 -0400 Subject: [PATCH] core[patch]: Add kwargs to Runnable (#27008) Fixes #26685 --------- Co-authored-by: Tibor Reiss --- .../langchain_community/utilities/nvidia_riva.py | 8 ++++++-- .../langchain_community/vectorstores/vectara.py | 1 + libs/core/langchain_core/beta/runnables/context.py | 8 ++++++-- libs/core/langchain_core/output_parsers/base.py | 10 ++++++++-- libs/core/langchain_core/prompts/base.py | 2 +- libs/core/langchain_core/runnables/base.py | 6 ++++-- libs/core/langchain_core/runnables/router.py | 2 +- .../tests/unit_tests/runnables/test_configurable.py | 8 ++++++-- libs/core/tests/unit_tests/runnables/test_graph.py | 4 +++- .../tests/unit_tests/runnables/test_runnable.py | 13 +++++++++---- .../unit_tests/runnables/test_runnable_events_v2.py | 4 +++- .../langchain/agents/openai_assistant/base.py | 2 +- 12 files changed, 49 insertions(+), 19 deletions(-) diff --git a/libs/community/langchain_community/utilities/nvidia_riva.py b/libs/community/langchain_community/utilities/nvidia_riva.py index c798935e2a7..40a788a85f7 100644 --- a/libs/community/langchain_community/utilities/nvidia_riva.py +++ b/libs/community/langchain_community/utilities/nvidia_riva.py @@ -471,7 +471,8 @@ class RivaASR( def invoke( self, input: ASRInputType, - _: Optional[RunnableConfig] = None, + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> ASROutputType: """Transcribe the audio bytes into a string with Riva.""" # create an output text generator with Riva @@ -567,7 +568,10 @@ class RivaTTS( ) from err def invoke( - self, input: TTSInputType, _: Union[RunnableConfig, None] = None + self, + input: TTSInputType, + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> TTSOutputType: """Perform TTS by taking a string and outputting the entire audio file.""" return b"".join(self.transform(iter([input]))) diff --git a/libs/community/langchain_community/vectorstores/vectara.py b/libs/community/langchain_community/vectorstores/vectara.py index 2d217333f83..8683e22e2d8 100644 --- a/libs/community/langchain_community/vectorstores/vectara.py +++ b/libs/community/langchain_community/vectorstores/vectara.py @@ -888,6 +888,7 @@ class VectaraRAG(Runnable): self, input: str, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> dict: res = {"answer": ""} for chunk in self.stream(input): diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 36222249d16..739798eb172 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -183,7 +183,9 @@ class ContextGet(RunnableSerializable): for id_ in self.ids ] - def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + def invoke( + self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: config = ensure_config(config) configurable = config.get("configurable", {}) if isinstance(self.key, list): @@ -280,7 +282,9 @@ class ContextSet(RunnableSerializable): for id_ in self.ids ] - def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + def invoke( + self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: config = ensure_config(config) configurable = config.get("configurable", {}) for id_, mapper in zip(self.ids, self.keys.values()): diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index bbecc94ee38..9c132b4a0bf 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -80,7 +80,10 @@ class BaseGenerationOutputParser( return T # type: ignore[misc] def invoke( - self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None + self, + input: Union[str, BaseMessage], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> T: if isinstance(input, BaseMessage): return self._call_with_config( @@ -180,7 +183,10 @@ class BaseOutputParser( ) def invoke( - self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None + self, + input: Union[str, BaseMessage], + config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> T: if isinstance(input, BaseMessage): return self._call_with_config( diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index a5a1c85400d..f6932d6060a 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -174,7 +174,7 @@ class BasePromptTemplate( return await self.aformat_prompt(**_inner_input) def invoke( - self, input: dict, config: Optional[RunnableConfig] = None + self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> PromptValue: """Invoke the prompt. diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index cab3c40f7c6..730cd38931d 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -723,7 +723,9 @@ class Runnable(Generic[Input, Output], ABC): """ --- Public API --- """ @abstractmethod - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: """Transform a single input into an output. Override to implement. Args: @@ -3669,7 +3671,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): return "{\n " + map_for_repr + "\n}" def invoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> dict[str, Any]: from langchain_core.callbacks.manager import CallbackManager diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index 43b4761bbd5..c71cb85e4f7 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -96,7 +96,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): return ["langchain", "schema", "runnable"] def invoke( - self, input: RouterInput, config: Optional[RunnableConfig] = None + self, input: RouterInput, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: key = input["key"] actual_input = input["input"] diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index 8bd6e760677..99f10f6c600 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -31,7 +31,9 @@ class MyRunnable(RunnableSerializable[str, str]): self._my_hidden_property = self.my_property return self - def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: + def invoke( + self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: return input + self._my_hidden_property def my_custom_function(self) -> str: @@ -51,7 +53,9 @@ class MyRunnable(RunnableSerializable[str, str]): class MyOtherRunnable(RunnableSerializable[str, str]): my_other_property: str - def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any: + def invoke( + self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: return input + self.my_other_property def my_other_custom_function(self) -> str: diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 2699e6cab48..870842bcc49 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel from syrupy import SnapshotAssertion @@ -363,6 +363,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None: self, input: int, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> int: return input @@ -387,6 +388,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: self, input: int, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> int: return input diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 09de13b36bd..211fab46833 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -193,6 +193,7 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]): self, input: str, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> int: return len(input) @@ -3966,7 +3967,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: def __init__(self, fail_starts_with: str) -> None: self.fail_starts_with = fail_starts_with - def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + def invoke( + self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: raise NotImplementedError() def _batch( @@ -4085,7 +4088,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: def __init__(self, fail_starts_with: str) -> None: self.fail_starts_with = fail_starts_with - def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + def invoke( + self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Any: raise NotImplementedError() async def _abatch( @@ -5205,7 +5210,7 @@ def test_default_transform_with_dicts() -> None: class CustomRunnable(RunnableSerializable[Input, Output]): def invoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: return cast(Output, input) # type: ignore @@ -5226,7 +5231,7 @@ async def test_default_atransform_with_dicts() -> None: class CustomRunnable(RunnableSerializable[Input, Output]): def invoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: return cast(Output, input) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index ac355dddf2e..4014c9687e9 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -2053,7 +2053,9 @@ class StreamingRunnable(Runnable[Input, Output]): """Initialize the runnable.""" self.iterable = iterable - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: """Invoke the runnable.""" raise ValueError("Server side error") diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 0fbbbe85e8b..dc85fb03037 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -277,7 +277,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): return cls(assistant_id=assistant.id, client=client, **kwargs) def invoke( - self, input: dict, config: Optional[RunnableConfig] = None + self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> OutputType: """Invoke assistant.