Add kwargs to all other optional runnable methods (#9439)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-08-18 15:04:26 +01:00 committed by GitHub
parent 463019ac3e
commit d5eb228874
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 114 additions and 34 deletions

View File

@ -79,7 +79,10 @@ class BaseGenerationOutputParser(
) )
async def ainvoke( async def ainvoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None self,
input: str | BaseMessage,
config: RunnableConfig | None = None,
**kwargs: Optional[Any],
) -> T: ) -> T:
if isinstance(input, BaseMessage): if isinstance(input, BaseMessage):
return await self._acall_with_config( return await self._acall_with_config(
@ -147,7 +150,10 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
) )
async def ainvoke( async def ainvoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None self,
input: str | BaseMessage,
config: RunnableConfig | None = None,
**kwargs: Optional[Any],
) -> T: ) -> T:
if isinstance(input, BaseMessage): if isinstance(input, BaseMessage):
return await self._acall_with_config( return await self._acall_with_config(

View File

@ -116,7 +116,10 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
) )
async def ainvoke( async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> List[Document]: ) -> List[Document]:
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents: if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
# If the retriever doesn't implement async, use default implementation # If the retriever doesn't implement async, use default implementation

View File

@ -5,6 +5,7 @@ import copy
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from functools import partial
from itertools import tee from itertools import tee
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -83,14 +84,14 @@ class Runnable(Generic[Input, Output], ABC):
... ...
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output: ) -> Output:
""" """
Default implementation of ainvoke, which calls invoke in a thread pool. Default implementation of ainvoke, which calls invoke in a thread pool.
Subclasses should override this method if they can run asynchronously. Subclasses should override this method if they can run asynchronously.
""" """
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(
None, self.invoke, input, config None, partial(self.invoke, **kwargs), input, config
) )
def batch( def batch(
@ -99,6 +100,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
""" """
Default implementation of batch, which calls invoke N times. Default implementation of batch, which calls invoke N times.
@ -108,10 +110,10 @@ class Runnable(Generic[Input, Output], ABC):
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0])] return [self.invoke(inputs[0], configs[0], **kwargs)]
with ThreadPoolExecutor(max_workers=max_concurrency) as executor: with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(self.invoke, inputs, configs)) return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
async def abatch( async def abatch(
self, self,
@ -119,33 +121,40 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
""" """
Default implementation of abatch, which calls ainvoke N times. Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
configs = self._get_config_list(config, len(inputs)) configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs) coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(max_concurrency, *coros) return await gather_with_concurrency(max_concurrency, *coros)
def stream( def stream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
""" """
Default implementation of stream, which calls invoke. Default implementation of stream, which calls invoke.
Subclasses should override this method if they support streaming output. Subclasses should override this method if they support streaming output.
""" """
yield self.invoke(input, config) yield self.invoke(input, config, **kwargs)
async def astream( async def astream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
""" """
Default implementation of astream, which calls ainvoke. Default implementation of astream, which calls ainvoke.
Subclasses should override this method if they support streaming output. Subclasses should override this method if they support streaming output.
""" """
yield await self.ainvoke(input, config) yield await self.ainvoke(input, config, **kwargs)
def transform( def transform(
self, self,
@ -601,7 +610,10 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
raise first_error raise first_error
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager
@ -650,6 +662,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -712,6 +725,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManager, AsyncCallbackManager,
@ -879,7 +893,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
return cast(Output, input) return cast(Output, input)
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager
@ -923,6 +940,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -976,6 +994,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManager, AsyncCallbackManager,
@ -1034,7 +1053,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
return cast(List[Output], inputs) return cast(List[Output], inputs)
def stream( def stream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -1111,7 +1133,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) )
async def astream( async def astream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
from langchain.callbacks.manager import AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager
@ -1280,7 +1305,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
return output return output
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
from langchain.callbacks.manager import AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager
@ -1379,7 +1407,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
) )
def stream( def stream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Dict[str, Any]]: ) -> Iterator[Dict[str, Any]]:
yield from self.transform(iter([input]), config) yield from self.transform(iter([input]), config)
@ -1443,7 +1474,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
yield chunk yield chunk
async def astream( async def astream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Dict[str, Any]]: ) -> AsyncIterator[Dict[str, Any]]:
async def input_aiter() -> AsyncIterator[Input]: async def input_aiter() -> AsyncIterator[Input]:
yield input yield input
@ -1472,7 +1506,12 @@ class RunnableLambda(Runnable[Input, Output]):
else: else:
return False return False
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return self._call_with_config(self.func, input, config) return self._call_with_config(self.func, input, config)
@ -1499,13 +1538,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(
return self.bound.invoke(input, config, **self.kwargs) self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs})
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
return await self.bound.ainvoke(input, config, **self.kwargs) return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs})
def batch( def batch(
self, self,
@ -1513,9 +1560,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return self.bound.batch( return self.bound.batch(
inputs, config, max_concurrency=max_concurrency, **self.kwargs inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
) )
async def abatch( async def abatch(
@ -1524,20 +1572,29 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return await self.bound.abatch( return await self.bound.abatch(
inputs, config, max_concurrency=max_concurrency, **self.kwargs inputs, config, max_concurrency=max_concurrency, **{**self.kwargs, **kwargs}
) )
def stream( def stream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.stream(input, config, **self.kwargs) yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs})
async def astream( async def astream(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.astream(input, config, **self.kwargs): async for item in self.bound.astream(
input, config, **{**self.kwargs, **kwargs}
):
yield item yield item
def transform( def transform(

View File

@ -32,7 +32,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
return self._call_with_config(identity, input, config) return self._call_with_config(identity, input, config)
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Input: ) -> Input:
return await self._acall_with_config(aidentity, input, config) return await self._acall_with_config(aidentity, input, config)

View File

@ -104,7 +104,10 @@ class RouterRunnable(
return runnable.invoke(actual_input, config) return runnable.invoke(actual_input, config)
async def ainvoke( async def ainvoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None self,
input: RouterInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
key = input["key"] key = input["key"]
actual_input = input["input"] actual_input = input["input"]
@ -120,6 +123,7 @@ class RouterRunnable(
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
keys = [input["key"] for input in inputs] keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs] actual_inputs = [input["input"] for input in inputs]
@ -144,6 +148,7 @@ class RouterRunnable(
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*, *,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
keys = [input["key"] for input in inputs] keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs] actual_inputs = [input["input"] for input in inputs]
@ -161,7 +166,10 @@ class RouterRunnable(
) )
def stream( def stream(
self, input: RouterInput, config: Optional[RunnableConfig] = None self,
input: RouterInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
key = input["key"] key = input["key"]
actual_input = input["input"] actual_input = input["input"]
@ -172,7 +180,10 @@ class RouterRunnable(
yield from runnable.stream(actual_input, config) yield from runnable.stream(actual_input, config)
async def astream( async def astream(
self, input: RouterInput, config: Optional[RunnableConfig] = None self,
input: RouterInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
key = input["key"] key = input["key"]
actual_input = input["input"] actual_input = input["input"]