Add .with_config() method to Runnables, Add run_id, run_name to RunnableConfig (#9694)

- with_config() allows binding any config values to a Runnable, like
.bind() does for kwargs

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-09-01 15:48:46 +01:00 committed by GitHub
commit 50a5c5bcf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 328 additions and 61 deletions

View File

@ -68,6 +68,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs, **kwargs,
) )
@ -89,6 +90,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"), callbacks=config.get("callbacks"),
tags=config.get("tags"), tags=config.get("tags"),
metadata=config.get("metadata"), metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs, **kwargs,
) )
@ -235,6 +237,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Execute the chain. """Execute the chain.
@ -276,6 +279,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
name=run_name,
) )
try: try:
outputs = ( outputs = (
@ -302,6 +306,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*, *,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Asynchronously execute the chain. """Asynchronously execute the chain.
@ -343,6 +348,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
inputs, inputs,
name=run_name,
) )
try: try:
outputs = ( outputs = (

View File

@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.schema.runnable import RunnableConfig from langchain.schema.runnable import RunnableConfig
from langchain.schema.runnable.config import get_config_list
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
config = self._get_config_list(config, len(inputs)) config = get_config_list(config, len(inputs))
if max_concurrency is None: if max_concurrency is None:
llm_result = self.generate_prompt( llm_result = self.generate_prompt(
@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
None, self.batch, inputs, config, max_concurrency None, self.batch, inputs, config, max_concurrency
) )
config = self._get_config_list(config, len(inputs)) config = get_config_list(config, len(inputs))
if max_concurrency is None: if max_concurrency is None:
llm_result = await self.agenerate_prompt( llm_result = await self.agenerate_prompt(

View File

@ -42,6 +42,7 @@ from langchain.schema.runnable.config import (
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
get_config_list,
get_executor_for_config, get_executor_for_config,
patch_config, patch_config,
) )
@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of batch, which calls invoke N times. Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
# If there's only one input, don't bother with the executor # If there's only one input, don't bother with the executor
if len(inputs) == 1: if len(inputs) == 1:
@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of abatch, which calls ainvoke N times. Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs) coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
@ -210,7 +211,20 @@ class Runnable(Generic[Input, Output], ABC):
""" """
Bind arguments to a Runnable, returning a new Runnable. Bind arguments to a Runnable, returning a new Runnable.
""" """
return RunnableBinding(bound=self, kwargs=kwargs) return RunnableBinding(bound=self, kwargs=kwargs, config={})
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
"""
Bind config to a Runnable, returning a new Runnable.
"""
return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={}
)
def map(self) -> Runnable[List[Input], List[Output]]: def map(self) -> Runnable[List[Input], List[Output]]:
""" """
@ -233,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC):
""" --- Helper methods for Subclasses --- """ """ --- Helper methods for Subclasses --- """
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def _call_with_config( def _call_with_config(
self, self,
func: Union[ func: Union[
@ -273,6 +266,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
input, input,
run_type=run_type, run_type=run_type,
name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(func): if accepts_run_manager_and_config(func):
@ -314,6 +308,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
input, input,
run_type=run_type, run_type=run_type,
name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(func): if accepts_run_manager_and_config(func):
@ -371,6 +366,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
run_type=run_type, run_type=run_type,
name=config.get("run_name"),
) )
try: try:
if accepts_run_manager_and_config(transformer): if accepts_run_manager_and_config(transformer):
@ -451,6 +447,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
run_type=run_type, run_type=run_type,
name=config.get("run_name"),
) )
try: try:
# mypy can't quite work out thew type guard here, but this is safe, # mypy can't quite work out thew type guard here, but this is safe,
@ -526,7 +523,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None first_error = None
for runnable in self.runnables: for runnable in self.runnables:
try: try:
@ -558,7 +557,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
first_error = None first_error = None
for runnable in self.runnables: for runnable in self.runnables:
@ -590,7 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -606,9 +607,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input # start the root runs, one per input
run_managers = [ run_managers = [
cm.on_chain_start( cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self),
input if isinstance(input, dict) else {"input": input},
name=config.get("run_name"),
) )
for cm, input in zip(callback_managers, inputs) for cm, input, config in zip(callback_managers, inputs, configs)
] ]
first_error = None first_error = None
@ -648,7 +651,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import AsyncCallbackManager from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -664,8 +667,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
# start the root runs, one per input # start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
cm.on_chain_start(dumpd(self), input) cm.on_chain_start(
for cm, input in zip(callback_managers, inputs) dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
) )
) )
@ -770,7 +777,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
# invoke all steps in sequence # invoke all steps in sequence
try: try:
@ -798,7 +807,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
# invoke all steps in sequence # invoke all steps in sequence
try: try:
@ -825,7 +836,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
CallbackManager.configure( CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -840,8 +851,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
] ]
# start the root runs, one per input # start the root runs, one per input
run_managers = [ run_managers = [
cm.on_chain_start(dumpd(self), input) cm.on_chain_start(
for cm, input in zip(callback_managers, inputs) dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
] ]
# invoke # invoke
@ -876,7 +891,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) )
# setup callbacks # setup callbacks
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
AsyncCallbackManager.configure( AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
@ -892,8 +907,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# start the root runs, one per input # start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*( *(
cm.on_chain_start(dumpd(self), input) cm.on_chain_start(
for cm, input in zip(callback_managers, inputs) dumpd(self),
input,
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
) )
) )
@ -929,7 +948,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
streaming_start_index = 0 streaming_start_index = 0
@ -996,7 +1017,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last] steps = [self.first] + self.middle + [self.last]
streaming_start_index = len(steps) - 1 streaming_start_index = len(steps) - 1
@ -1127,7 +1150,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_metadata=None, local_metadata=None,
) )
# start the root run # start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), input) run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
# gather results from all steps # gather results from all steps
try: try:
@ -1166,7 +1191,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start(dumpd(self), input) run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
)
# gather results from all steps # gather results from all steps
try: try:
@ -1479,6 +1506,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
kwargs: Mapping[str, Any] kwargs: Mapping[str, Any]
config: Mapping[str, Any] = Field(default_factory=dict)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -1490,8 +1519,31 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_namespace(self) -> List[str]: def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1] return self.__class__.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
)
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs},
)
def invoke( def invoke(
self, self,
@ -1499,7 +1551,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs}) return self.bound.invoke(
input,
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
async def ainvoke( async def ainvoke(
self, self,
@ -1507,7 +1563,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs}) return await self.bound.ainvoke(
input,
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
def batch( def batch(
self, self,
@ -1515,7 +1575,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs}) if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
)
else:
configs = [
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
async def abatch( async def abatch(
self, self,
@ -1523,7 +1592,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs}) if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
)
else:
configs = [
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
def stream( def stream(
self, self,
@ -1531,7 +1609,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs}) yield from self.bound.stream(
input,
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
async def astream( async def astream(
self, self,
@ -1540,7 +1622,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.astream( async for item in self.bound.astream(
input, config, **{**self.kwargs, **kwargs} input,
self._merge_config(config),
**{**self.kwargs, **kwargs},
): ):
yield item yield item
@ -1550,7 +1634,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs}) yield from self.bound.transform(
input,
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
async def atransform( async def atransform(
self, self,
@ -1559,11 +1647,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.atransform( async for item in self.bound.atransform(
input, config, **{**self.kwargs, **kwargs} input,
self._merge_config(config),
**{**self.kwargs, **kwargs},
): ):
yield item yield item
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
def coerce_to_runnable( def coerce_to_runnable(
thing: Union[ thing: Union[
Runnable[Input, Output], Runnable[Input, Output],

View File

@ -3,7 +3,9 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
from typing_extensions import TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.base import BaseCallbackManager, Callbacks
@ -31,6 +33,11 @@ class RunnableConfig(TypedDict, total=False):
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
""" """
run_name: str
"""
Name for the tracer run for this call. Defaults to the name of the class.
"""
_locals: Dict[str, Any] _locals: Dict[str, Any]
""" """
Local variables Local variables
@ -48,7 +55,7 @@ class RunnableConfig(TypedDict, total=False):
""" """
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
empty = RunnableConfig( empty = RunnableConfig(
tags=[], tags=[],
metadata={}, metadata={},
@ -61,20 +68,52 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
return empty return empty
def get_config_list(
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
"""
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
)
def patch_config( def patch_config(
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
*, *,
deep_copy_locals: bool = False, deep_copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None, callbacks: Optional[BaseCallbackManager] = None,
recursion_limit: Optional[int] = None, recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = ensure_config(config) config = ensure_config(config)
if deep_copy_locals: if deep_copy_locals:
config["_locals"] = deepcopy(config["_locals"]) config["_locals"] = deepcopy(config["_locals"])
if callbacks is not None: if callbacks is not None:
# If we're replacing callbacks we need to unset run_name
# As that should apply only to the same run as the original callbacks
config["callbacks"] = callbacks config["callbacks"] = callbacks
if "run_name" in config:
del config["run_name"]
if recursion_limit is not None: if recursion_limit is not None:
config["recursion_limit"] = recursion_limit config["recursion_limit"] = recursion_limit
if max_concurrency is not None:
config["max_concurrency"] = max_concurrency
if run_name is not None:
config["run_name"] = run_name
return config return config

View File

@ -23,7 +23,7 @@ from langchain.schema.runnable.base import (
RunnableSequence, RunnableSequence,
coerce_to_runnable, coerce_to_runnable,
) )
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig, get_config_list
from langchain.schema.runnable.utils import gather_with_concurrency from langchain.schema.runnable.utils import gather_with_concurrency
@ -131,7 +131,7 @@ class RouterRunnable(
raise ValueError("One or more keys do not have a corresponding runnable") raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys] runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor: with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list( return list(
executor.map( executor.map(
@ -156,7 +156,7 @@ class RouterRunnable(
raise ValueError("One or more keys do not have a corresponding runnable") raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys] runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
return await gather_with_concurrency( return await gather_with_concurrency(
max_concurrency, max_concurrency,
*( *(

View File

@ -2081,7 +2081,8 @@
"stop": [ "stop": [
"Thought:" "Thought:"
] ]
} },
"config": {}
} }
}, },
"llm": { "llm": {

View File

@ -11,6 +11,7 @@ from langchain import PromptTemplate
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps from langchain.load.dump import dumpd, dumps
@ -112,6 +113,124 @@ class FakeRetriever(BaseRetriever):
return [Document(page_content="foo"), Document(page_content="bar")] return [Document(page_content="foo"), Document(page_content="bar")]
@pytest.mark.asyncio
async def test_with_config(mocker: MockerFixture) -> None:
fake = FakeRunnable()
spy = mocker.spy(fake, "invoke")
assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
]
spy.reset_mock()
fake_1: Runnable = RunnablePassthrough()
fake_2: Runnable = RunnablePassthrough()
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
tags=["b-tag"], max_concurrency=5
)
assert sequence.invoke("hello") == "hello"
assert len(spy_seq_step.call_args_list) == 2
for i, call in enumerate(spy_seq_step.call_args_list):
assert call.args[1] == "hello"
if i == 0:
assert call.args[2].get("tags") == ["a-tag"]
assert call.args[2].get("max_concurrency") is None
else:
assert call.args[2].get("tags") == ["b-tag"]
assert call.args[2].get("max_concurrency") == 5
spy_seq_step.reset_mock()
assert [
*fake.with_config(tags=["a-tag"]).stream(
"hello", dict(metadata={"key": "value"})
)
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.with_config(recursion_limit=5).batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
spy.reset_mock()
assert fake.with_config(metadata={"a": "b"}).batch(
["hello", "wooorld"], dict(tags=["a-tag"])
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock()
handler = ConsoleCallbackHandler()
assert (
await fake.with_config(metadata={"a": "b"}).ainvoke(
"hello", config={"callbacks": [handler]}
)
== 5
)
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})),
]
spy.reset_mock()
assert [
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"a": "b"})),
]
spy.reset_mock()
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
["hello", "wooorld"], dict(metadata={"key": "value"})
) == [
5,
7,
]
assert spy.call_args_list == [
mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
recursion_limit=5,
),
),
mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
recursion_limit=5,
),
),
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None: async def test_default_method_implementations(mocker: MockerFixture) -> None:
fake = FakeRunnable() fake = FakeRunnable()
@ -1125,6 +1244,14 @@ async def test_map_astream_iterator_input() -> None:
assert final_value.get("passthrough") == llm_res assert final_value.get("passthrough") == llm_res
def test_with_config_with_config() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])
assert dumpd(
llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
def test_bind_bind() -> None: def test_bind_bind() -> None:
llm = FakeListLLM(responses=["i'm a textbot"]) llm = FakeListLLM(responses=["i'm a textbot"])