mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 10:13:29 +00:00
core[patch]: Add kwargs to Runnable (#27008)
Fixes #26685 --------- Co-authored-by: Tibor Reiss <tibor.reiss@gmail.com>
This commit is contained in:
parent
2a6abd3f0a
commit
7fde2791dc
@ -471,7 +471,8 @@ class RivaASR(
|
||||
def invoke(
|
||||
self,
|
||||
input: ASRInputType,
|
||||
_: Optional[RunnableConfig] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> ASROutputType:
|
||||
"""Transcribe the audio bytes into a string with Riva."""
|
||||
# create an output text generator with Riva
|
||||
@ -567,7 +568,10 @@ class RivaTTS(
|
||||
) from err
|
||||
|
||||
def invoke(
|
||||
self, input: TTSInputType, _: Union[RunnableConfig, None] = None
|
||||
self,
|
||||
input: TTSInputType,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> TTSOutputType:
|
||||
"""Perform TTS by taking a string and outputting the entire audio file."""
|
||||
return b"".join(self.transform(iter([input])))
|
||||
|
@ -888,6 +888,7 @@ class VectaraRAG(Runnable):
|
||||
self,
|
||||
input: str,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
res = {"answer": ""}
|
||||
for chunk in self.stream(input):
|
||||
|
@ -183,7 +183,9 @@ class ContextGet(RunnableSerializable):
|
||||
for id_ in self.ids
|
||||
]
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
if isinstance(self.key, list):
|
||||
@ -280,7 +282,9 @@ class ContextSet(RunnableSerializable):
|
||||
for id_ in self.ids
|
||||
]
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
config = ensure_config(config)
|
||||
configurable = config.get("configurable", {})
|
||||
for id_, mapper in zip(self.ids, self.keys.values()):
|
||||
|
@ -80,7 +80,10 @@ class BaseGenerationOutputParser(
|
||||
return T # type: ignore[misc]
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return self._call_with_config(
|
||||
@ -180,7 +183,10 @@ class BaseOutputParser(
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||
self,
|
||||
input: Union[str, BaseMessage],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
if isinstance(input, BaseMessage):
|
||||
return self._call_with_config(
|
||||
|
@ -174,7 +174,7 @@ class BasePromptTemplate(
|
||||
return await self.aformat_prompt(**_inner_input)
|
||||
|
||||
def invoke(
|
||||
self, input: dict, config: Optional[RunnableConfig] = None
|
||||
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> PromptValue:
|
||||
"""Invoke the prompt.
|
||||
|
||||
|
@ -723,7 +723,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
""" --- Public API --- """
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""Transform a single input into an output. Override to implement.
|
||||
|
||||
Args:
|
||||
@ -3669,7 +3671,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
return "{\n " + map_for_repr + "\n}"
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
|
@ -96,7 +96,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
return ["langchain", "schema", "runnable"]
|
||||
|
||||
def invoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
|
@ -31,7 +31,9 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
self._my_hidden_property = self.my_property
|
||||
return self
|
||||
|
||||
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
return input + self._my_hidden_property
|
||||
|
||||
def my_custom_function(self) -> str:
|
||||
@ -51,7 +53,9 @@ class MyRunnable(RunnableSerializable[str, str]):
|
||||
class MyOtherRunnable(RunnableSerializable[str, str]):
|
||||
my_other_property: str
|
||||
|
||||
def invoke(self, input: str, config: Optional[RunnableConfig] = None) -> Any:
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
return input + self.my_other_property
|
||||
|
||||
def my_other_custom_function(self) -> str:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from syrupy import SnapshotAssertion
|
||||
@ -363,6 +363,7 @@ def test_runnable_get_graph_with_invalid_input_type() -> None:
|
||||
self,
|
||||
input: int,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
return input
|
||||
|
||||
@ -387,6 +388,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
|
||||
self,
|
||||
input: int,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
return input
|
||||
|
||||
|
@ -193,6 +193,7 @@ class FakeRunnableSerializable(RunnableSerializable[str, int]):
|
||||
self,
|
||||
input: str,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
return len(input)
|
||||
|
||||
@ -3966,7 +3967,9 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
def __init__(self, fail_starts_with: str) -> None:
|
||||
self.fail_starts_with = fail_starts_with
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _batch(
|
||||
@ -4085,7 +4088,9 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
def __init__(self, fail_starts_with: str) -> None:
|
||||
self.fail_starts_with = fail_starts_with
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
def invoke(
|
||||
self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _abatch(
|
||||
@ -5205,7 +5210,7 @@ def test_default_transform_with_dicts() -> None:
|
||||
|
||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return cast(Output, input) # type: ignore
|
||||
|
||||
@ -5226,7 +5231,7 @@ async def test_default_atransform_with_dicts() -> None:
|
||||
|
||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return cast(Output, input)
|
||||
|
||||
|
@ -2053,7 +2053,9 @@ class StreamingRunnable(Runnable[Input, Output]):
|
||||
"""Initialize the runnable."""
|
||||
self.iterable = iterable
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
"""Invoke the runnable."""
|
||||
raise ValueError("Server side error")
|
||||
|
||||
|
@ -277,7 +277,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
return cls(assistant_id=assistant.id, client=client, **kwargs)
|
||||
|
||||
def invoke(
|
||||
self, input: dict, config: Optional[RunnableConfig] = None
|
||||
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> OutputType:
|
||||
"""Invoke assistant.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user