core[patch]: Add kwargs to Runnable (#27008)

Fixes #26685

---------

Co-authored-by: Tibor Reiss <tibor.reiss@gmail.com>
This commit is contained in:
Eugene Yurtsev 2024-09-30 16:45:29 -04:00 committed by GitHub
parent 2a6abd3f0a
commit 7fde2791dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 49 additions and 19 deletions

View File

@ -471,7 +471,8 @@ class RivaASR(
def invoke(
self,
input: ASRInputType,
_: Optional[RunnableConfig] = None,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> ASROutputType:
"""Transcribe the audio bytes into a string with Riva."""
# create an output text generator with Riva
@ -567,7 +568,10 @@ class RivaTTS(
) from err
def invoke(
self, input: TTSInputType, _: Union[RunnableConfig, None] = None
self,
input: TTSInputType,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> TTSOutputType:
"""Perform TTS by taking a string and outputting the entire audio file."""
return b"".join(self.transform(iter([input])))

View File

@ -888,6 +888,7 @@ class VectaraRAG(Runnable):
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> dict:
res = {"answer": ""}
for chunk in self.stream(input):

View File

@ -183,7 +183,9 @@ class ContextGet(RunnableSerializable):
for id_ in self.ids
]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
if isinstance(self.key, list):
@ -280,7 +282,9 @@ class ContextSet(RunnableSerializable):
for id_ in self.ids
]
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
config = ensure_config(config)
configurable = config.get("configurable", {})
for id_, mapper in zip(self.ids, self.keys.values()):

View File

@ -80,7 +80,10 @@ class BaseGenerationOutputParser(
return T # type: ignore[misc]
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
self,
input: Union[str, BaseMessage],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
@ -180,7 +183,10 @@ class BaseOutputParser(
)
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
self,
input: Union[str, BaseMessage],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(

View File

@ -174,7 +174,7 @@ class BasePromptTemplate(
return await self.aformat_prompt(**_inner_input)
def invoke(
self, input: dict, config: Optional[RunnableConfig] = None
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> PromptValue:
"""Invoke the prompt.

View File

@ -723,7 +723,9 @@ class Runnable(Generic[Input, Output], ABC):
""" --- Public API --- """
@abstractmethod
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Transform a single input into an output. Override to implement.
Args:
@ -3669,7 +3671,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
return "{\n " + map_for_repr + "\n}"
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> dict[str, Any]:
from langchain_core.callbacks.manager import CallbackManager

View File

@ -96,7 +96,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
return ["langchain", "schema", "runnable"]
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
self, input: RouterInput, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
key = input["key"]
actual_input = input["input"]

View File

@ -31,7 +31,9 @@ class MyRunnable(RunnableSerializable[str, str]):
self._my_hidden_property = self.my_property
return self
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
return input + self._my_hidden_property
def my_custom_function(self) -> str:
@ -51,7 +53,9 @@ class MyRunnable(RunnableSerializable[str, str]):
class MyOtherRunnable(RunnableSerializable[str, str]):
my_other_property: str
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
return input + self.my_other_property
def my_other_custom_function(self) -> str:

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from syrupy import SnapshotAssertion
@ -363,6 +363,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
self,
input: int,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> int:
return input
@ -387,6 +388,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
self,
input: int,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> int:
return input

View File

@ -193,6 +193,7 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> int:
return len(input)
@ -3966,7 +3967,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
def __init__(self, fail_starts_with: str) -> None:
self.fail_starts_with = fail_starts_with
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
raise NotImplementedError()
def _batch(
@ -4085,7 +4088,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
def __init__(self, fail_starts_with: str) -> None:
self.fail_starts_with = fail_starts_with
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
def invoke(
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Any:
raise NotImplementedError()
async def _abatch(
@ -5205,7 +5210,7 @@ def test_default_transform_with_dicts() -> None:
class CustomRunnable(RunnableSerializable[Input, Output]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return cast(Output, input) # type: ignore
@ -5226,7 +5231,7 @@ async def test_default_atransform_with_dicts() -> None:
class CustomRunnable(RunnableSerializable[Input, Output]):
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return cast(Output, input)

View File

@ -2053,7 +2053,9 @@ class StreamingRunnable(Runnable[Input, Output]):
"""Initialize the runnable."""
self.iterable = iterable
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the runnable."""
raise ValueError("Server side error")

View File

@ -277,7 +277,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return cls(assistant_id=assistant.id, client=client, **kwargs)
def invoke(
self, input: dict, config: Optional[RunnableConfig] = None
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
"""Invoke assistant.