mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
Add Runnable.bind method to attach kwargs to a Runnable that will be passed to all invoke/stream/batch calls when it is run (#8368)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure you're PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
cf608f876b
commit
0eca3e7d90
@ -131,6 +131,12 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> AsyncIterator[Output]:
|
||||
yield await self.ainvoke(input, config)
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind arguments to a Runnable, returning a new Runnable.
|
||||
"""
|
||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||
|
||||
def _get_config_list(
|
||||
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||
) -> List[RunnableConfig]:
|
||||
@ -692,6 +698,60 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
return self._call_with_config(lambda x: x, input, config)
|
||||
|
||||
|
||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
bound: Runnable[Input, Output]
|
||||
|
||||
kwargs: Mapping[str, Any]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
return self.bound.invoke(input, config, **self.kwargs)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
return await self.bound.ainvoke(input, config, **self.kwargs)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
return self.bound.batch(
|
||||
inputs, config, max_concurrency=max_concurrency, **self.kwargs
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
return await self.bound.abatch(
|
||||
inputs, config, max_concurrency=max_concurrency, **self.kwargs
|
||||
)
|
||||
|
||||
def stream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.stream(input, config, **self.kwargs)
|
||||
|
||||
async def astream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.astream(input, config, **self.kwargs):
|
||||
yield item
|
||||
|
||||
|
||||
def _patch_config(
|
||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||
) -> RunnableConfig:
|
||||
|
File diff suppressed because one or more lines are too long
@ -566,7 +566,7 @@ def test_seq_prompt_map(
|
||||
prompt
|
||||
| passthrough
|
||||
| {
|
||||
"chat": chat,
|
||||
"chat": chat.bind(stop=["Thought:"]),
|
||||
"llm": llm,
|
||||
"passthrough": passthrough,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user