mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
Add Runnable.bind method to attach kwargs to a Runnable that will be passed to all invoke/stream/batch calls when it is run (#8368)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure you're PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
cf608f876b
commit
0eca3e7d90
@ -131,6 +131,12 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
yield await self.ainvoke(input, config)
|
yield await self.ainvoke(input, config)
|
||||||
|
|
||||||
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
|
"""
|
||||||
|
Bind arguments to a Runnable, returning a new Runnable.
|
||||||
|
"""
|
||||||
|
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||||
|
|
||||||
def _get_config_list(
|
def _get_config_list(
|
||||||
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||||
) -> List[RunnableConfig]:
|
) -> List[RunnableConfig]:
|
||||||
@ -692,6 +698,60 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
return self._call_with_config(lambda x: x, input, config)
|
return self._call_with_config(lambda x: x, input, config)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||||
|
bound: Runnable[Input, Output]
|
||||||
|
|
||||||
|
kwargs: Mapping[str, Any]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
|
return self.bound.invoke(input, config, **self.kwargs)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Output:
|
||||||
|
return await self.bound.ainvoke(input, config, **self.kwargs)
|
||||||
|
|
||||||
|
def batch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
max_concurrency: Optional[int] = None,
|
||||||
|
) -> List[Output]:
|
||||||
|
return self.bound.batch(
|
||||||
|
inputs, config, max_concurrency=max_concurrency, **self.kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def abatch(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
max_concurrency: Optional[int] = None,
|
||||||
|
) -> List[Output]:
|
||||||
|
return await self.bound.abatch(
|
||||||
|
inputs, config, max_concurrency=max_concurrency, **self.kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Iterator[Output]:
|
||||||
|
yield from self.bound.stream(input, config, **self.kwargs)
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> AsyncIterator[Output]:
|
||||||
|
async for item in self.bound.astream(input, config, **self.kwargs):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
def _patch_config(
|
def _patch_config(
|
||||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
|
File diff suppressed because one or more lines are too long
@ -566,7 +566,7 @@ def test_seq_prompt_map(
|
|||||||
prompt
|
prompt
|
||||||
| passthrough
|
| passthrough
|
||||||
| {
|
| {
|
||||||
"chat": chat,
|
"chat": chat.bind(stop=["Thought:"]),
|
||||||
"llm": llm,
|
"llm": llm,
|
||||||
"passthrough": passthrough,
|
"passthrough": passthrough,
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user