mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 02:03:44 +00:00
Add .with_config() method to Runnables, Add run_id, run_name to RunnableConfig (#9694)
- with_config() allows binding any config values to a Runnable, like .bind() does for kwargs <!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. -->
This commit is contained in:
commit
50a5c5bcf8
@ -68,6 +68,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
tags=config.get("tags"),
|
tags=config.get("tags"),
|
||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
|
run_name=config.get("run_name"),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -89,6 +90,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
callbacks=config.get("callbacks"),
|
callbacks=config.get("callbacks"),
|
||||||
tags=config.get("tags"),
|
tags=config.get("tags"),
|
||||||
metadata=config.get("metadata"),
|
metadata=config.get("metadata"),
|
||||||
|
run_name=config.get("run_name"),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -235,6 +237,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
*,
|
*,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
include_run_info: bool = False,
|
include_run_info: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Execute the chain.
|
"""Execute the chain.
|
||||||
@ -276,6 +279,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
inputs,
|
inputs,
|
||||||
|
name=run_name,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
outputs = (
|
outputs = (
|
||||||
@ -302,6 +306,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
*,
|
*,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
include_run_info: bool = False,
|
include_run_info: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Asynchronously execute the chain.
|
"""Asynchronously execute the chain.
|
||||||
@ -343,6 +348,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
inputs,
|
inputs,
|
||||||
|
name=run_name,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
outputs = (
|
outputs = (
|
||||||
|
@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu
|
|||||||
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
|
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
|
||||||
from langchain.schema.output import GenerationChunk
|
from langchain.schema.output import GenerationChunk
|
||||||
from langchain.schema.runnable import RunnableConfig
|
from langchain.schema.runnable import RunnableConfig
|
||||||
|
from langchain.schema.runnable.config import get_config_list
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
max_concurrency: Optional[int] = None,
|
max_concurrency: Optional[int] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
config = self._get_config_list(config, len(inputs))
|
config = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
if max_concurrency is None:
|
if max_concurrency is None:
|
||||||
llm_result = self.generate_prompt(
|
llm_result = self.generate_prompt(
|
||||||
@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
None, self.batch, inputs, config, max_concurrency
|
None, self.batch, inputs, config, max_concurrency
|
||||||
)
|
)
|
||||||
|
|
||||||
config = self._get_config_list(config, len(inputs))
|
config = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
if max_concurrency is None:
|
if max_concurrency is None:
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
|
@ -42,6 +42,7 @@ from langchain.schema.runnable.config import (
|
|||||||
ensure_config,
|
ensure_config,
|
||||||
get_async_callback_manager_for_config,
|
get_async_callback_manager_for_config,
|
||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
|
get_config_list,
|
||||||
get_executor_for_config,
|
get_executor_for_config,
|
||||||
patch_config,
|
patch_config,
|
||||||
)
|
)
|
||||||
@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Default implementation of batch, which calls invoke N times.
|
Default implementation of batch, which calls invoke N times.
|
||||||
Subclasses should override this method if they can batch more efficiently.
|
Subclasses should override this method if they can batch more efficiently.
|
||||||
"""
|
"""
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
# If there's only one input, don't bother with the executor
|
# If there's only one input, don't bother with the executor
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Default implementation of abatch, which calls ainvoke N times.
|
Default implementation of abatch, which calls ainvoke N times.
|
||||||
Subclasses should override this method if they can batch more efficiently.
|
Subclasses should override this method if they can batch more efficiently.
|
||||||
"""
|
"""
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
|
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
|
||||||
|
|
||||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||||
@ -210,7 +211,20 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
Bind arguments to a Runnable, returning a new Runnable.
|
Bind arguments to a Runnable, returning a new Runnable.
|
||||||
"""
|
"""
|
||||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
return RunnableBinding(bound=self, kwargs=kwargs, config={})
|
||||||
|
|
||||||
|
def with_config(
|
||||||
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
"""
|
||||||
|
Bind config to a Runnable, returning a new Runnable.
|
||||||
|
"""
|
||||||
|
return RunnableBinding(
|
||||||
|
bound=self, config={**(config or {}), **kwargs}, kwargs={}
|
||||||
|
)
|
||||||
|
|
||||||
def map(self) -> Runnable[List[Input], List[Output]]:
|
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||||
"""
|
"""
|
||||||
@ -233,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
|
|
||||||
""" --- Helper methods for Subclasses --- """
|
""" --- Helper methods for Subclasses --- """
|
||||||
|
|
||||||
def _get_config_list(
|
|
||||||
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
|
||||||
) -> List[RunnableConfig]:
|
|
||||||
"""
|
|
||||||
Helper method to get a list of configs from a single config or a list of
|
|
||||||
configs, useful for subclasses overriding batch() or abatch().
|
|
||||||
"""
|
|
||||||
if length < 1:
|
|
||||||
raise ValueError(f"length must be >= 1, but got {length}")
|
|
||||||
if isinstance(config, list) and len(config) != length:
|
|
||||||
raise ValueError(
|
|
||||||
f"config must be a list of the same length as inputs, "
|
|
||||||
f"but got {len(config)} configs for {length} inputs"
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
list(map(ensure_config, config))
|
|
||||||
if isinstance(config, list)
|
|
||||||
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
self,
|
self,
|
||||||
func: Union[
|
func: Union[
|
||||||
@ -273,6 +266,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
input,
|
input,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
|
name=config.get("run_name"),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if accepts_run_manager_and_config(func):
|
if accepts_run_manager_and_config(func):
|
||||||
@ -314,6 +308,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
input,
|
input,
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
|
name=config.get("run_name"),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if accepts_run_manager_and_config(func):
|
if accepts_run_manager_and_config(func):
|
||||||
@ -371,6 +366,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
|
name=config.get("run_name"),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if accepts_run_manager_and_config(transformer):
|
if accepts_run_manager_and_config(transformer):
|
||||||
@ -451,6 +447,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
run_type=run_type,
|
run_type=run_type,
|
||||||
|
name=config.get("run_name"),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# mypy can't quite work out thew type guard here, but this is safe,
|
# mypy can't quite work out thew type guard here, but this is safe,
|
||||||
@ -526,7 +523,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
first_error = None
|
first_error = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
try:
|
try:
|
||||||
@ -558,7 +557,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
first_error = None
|
first_error = None
|
||||||
for runnable in self.runnables:
|
for runnable in self.runnables:
|
||||||
@ -590,7 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
CallbackManager.configure(
|
CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
@ -606,9 +607,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers = [
|
run_managers = [
|
||||||
cm.on_chain_start(
|
cm.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self),
|
||||||
|
input if isinstance(input, dict) else {"input": input},
|
||||||
|
name=config.get("run_name"),
|
||||||
)
|
)
|
||||||
for cm, input in zip(callback_managers, inputs)
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
]
|
]
|
||||||
|
|
||||||
first_error = None
|
first_error = None
|
||||||
@ -648,7 +651,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
from langchain.callbacks.manager import AsyncCallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
AsyncCallbackManager.configure(
|
AsyncCallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
@ -664,8 +667,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
cm.on_chain_start(dumpd(self), input)
|
cm.on_chain_start(
|
||||||
for cm, input in zip(callback_managers, inputs)
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -770,7 +777,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
@ -798,7 +807,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
# invoke all steps in sequence
|
# invoke all steps in sequence
|
||||||
try:
|
try:
|
||||||
@ -825,7 +836,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
CallbackManager.configure(
|
CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
@ -840,8 +851,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
]
|
]
|
||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers = [
|
run_managers = [
|
||||||
cm.on_chain_start(dumpd(self), input)
|
cm.on_chain_start(
|
||||||
for cm, input in zip(callback_managers, inputs)
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
]
|
]
|
||||||
|
|
||||||
# invoke
|
# invoke
|
||||||
@ -876,7 +891,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
AsyncCallbackManager.configure(
|
AsyncCallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
@ -892,8 +907,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
# start the root runs, one per input
|
# start the root runs, one per input
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
*(
|
*(
|
||||||
cm.on_chain_start(dumpd(self), input)
|
cm.on_chain_start(
|
||||||
for cm, input in zip(callback_managers, inputs)
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -929,7 +948,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_callback_manager_for_config(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
steps = [self.first] + self.middle + [self.last]
|
steps = [self.first] + self.middle + [self.last]
|
||||||
streaming_start_index = 0
|
streaming_start_index = 0
|
||||||
@ -996,7 +1017,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
steps = [self.first] + self.middle + [self.last]
|
steps = [self.first] + self.middle + [self.last]
|
||||||
streaming_start_index = len(steps) - 1
|
streaming_start_index = len(steps) - 1
|
||||||
@ -1127,7 +1150,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
local_metadata=None,
|
local_metadata=None,
|
||||||
)
|
)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
@ -1166,7 +1191,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = get_async_callback_manager_for_config(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
dumpd(self), input, name=config.get("run_name")
|
||||||
|
)
|
||||||
|
|
||||||
# gather results from all steps
|
# gather results from all steps
|
||||||
try:
|
try:
|
||||||
@ -1479,6 +1506,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
kwargs: Mapping[str, Any]
|
kwargs: Mapping[str, Any]
|
||||||
|
|
||||||
|
config: Mapping[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -1490,8 +1519,31 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
def lc_namespace(self) -> List[str]:
|
def lc_namespace(self) -> List[str]:
|
||||||
return self.__class__.__module__.split(".")[:-1]
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
|
copy = cast(RunnableConfig, dict(self.config))
|
||||||
|
if config:
|
||||||
|
for key in config:
|
||||||
|
# Even though the keys aren't literals this is correct
|
||||||
|
# because both dicts are same type
|
||||||
|
copy[key] = config[key] or copy.get(key) # type: ignore
|
||||||
|
return copy
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
|
return self.__class__(
|
||||||
|
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_config(
|
||||||
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config={**self.config, **(config or {}), **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
@ -1499,7 +1551,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs})
|
return self.bound.invoke(
|
||||||
|
input,
|
||||||
|
self._merge_config(config),
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
@ -1507,7 +1563,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs})
|
return await self.bound.ainvoke(
|
||||||
|
input,
|
||||||
|
self._merge_config(config),
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
self,
|
self,
|
||||||
@ -1515,7 +1575,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
|
if isinstance(config, list):
|
||||||
|
configs = cast(
|
||||||
|
List[RunnableConfig], [self._merge_config(conf) for conf in config]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configs = [
|
||||||
|
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||||
|
for _ in range(len(inputs))
|
||||||
|
]
|
||||||
|
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
@ -1523,7 +1592,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
|
if isinstance(config, list):
|
||||||
|
configs = cast(
|
||||||
|
List[RunnableConfig], [self._merge_config(conf) for conf in config]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configs = [
|
||||||
|
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||||
|
for _ in range(len(inputs))
|
||||||
|
]
|
||||||
|
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
@ -1531,7 +1609,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs})
|
yield from self.bound.stream(
|
||||||
|
input,
|
||||||
|
self._merge_config(config),
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
@ -1540,7 +1622,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.astream(
|
async for item in self.bound.astream(
|
||||||
input, config, **{**self.kwargs, **kwargs}
|
input,
|
||||||
|
self._merge_config(config),
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
@ -1550,7 +1634,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
|
yield from self.bound.transform(
|
||||||
|
input,
|
||||||
|
self._merge_config(config),
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
@ -1559,11 +1647,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.atransform(
|
async for item in self.bound.atransform(
|
||||||
input, config, **{**self.kwargs, **kwargs}
|
input,
|
||||||
|
self._merge_config(config),
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||||
|
|
||||||
|
|
||||||
def coerce_to_runnable(
|
def coerce_to_runnable(
|
||||||
thing: Union[
|
thing: Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
|
@ -3,7 +3,9 @@ from __future__ import annotations
|
|||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||||
@ -31,6 +33,11 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
run_name: str
|
||||||
|
"""
|
||||||
|
Name for the tracer run for this call. Defaults to the name of the class.
|
||||||
|
"""
|
||||||
|
|
||||||
_locals: Dict[str, Any]
|
_locals: Dict[str, Any]
|
||||||
"""
|
"""
|
||||||
Local variables
|
Local variables
|
||||||
@ -48,7 +55,7 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
empty = RunnableConfig(
|
empty = RunnableConfig(
|
||||||
tags=[],
|
tags=[],
|
||||||
metadata={},
|
metadata={},
|
||||||
@ -61,20 +68,52 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
return empty
|
return empty
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_list(
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
|
||||||
|
) -> List[RunnableConfig]:
|
||||||
|
"""
|
||||||
|
Helper method to get a list of configs from a single config or a list of
|
||||||
|
configs, useful for subclasses overriding batch() or abatch().
|
||||||
|
"""
|
||||||
|
if length < 1:
|
||||||
|
raise ValueError(f"length must be >= 1, but got {length}")
|
||||||
|
if isinstance(config, list) and len(config) != length:
|
||||||
|
raise ValueError(
|
||||||
|
f"config must be a list of the same length as inputs, "
|
||||||
|
f"but got {len(config)} configs for {length} inputs"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
list(map(ensure_config, config))
|
||||||
|
if isinstance(config, list)
|
||||||
|
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_config(
|
def patch_config(
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
*,
|
*,
|
||||||
deep_copy_locals: bool = False,
|
deep_copy_locals: bool = False,
|
||||||
callbacks: Optional[BaseCallbackManager] = None,
|
callbacks: Optional[BaseCallbackManager] = None,
|
||||||
recursion_limit: Optional[int] = None,
|
recursion_limit: Optional[int] = None,
|
||||||
|
max_concurrency: Optional[int] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if deep_copy_locals:
|
if deep_copy_locals:
|
||||||
config["_locals"] = deepcopy(config["_locals"])
|
config["_locals"] = deepcopy(config["_locals"])
|
||||||
if callbacks is not None:
|
if callbacks is not None:
|
||||||
|
# If we're replacing callbacks we need to unset run_name
|
||||||
|
# As that should apply only to the same run as the original callbacks
|
||||||
config["callbacks"] = callbacks
|
config["callbacks"] = callbacks
|
||||||
|
if "run_name" in config:
|
||||||
|
del config["run_name"]
|
||||||
if recursion_limit is not None:
|
if recursion_limit is not None:
|
||||||
config["recursion_limit"] = recursion_limit
|
config["recursion_limit"] = recursion_limit
|
||||||
|
if max_concurrency is not None:
|
||||||
|
config["max_concurrency"] = max_concurrency
|
||||||
|
if run_name is not None:
|
||||||
|
config["run_name"] = run_name
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ from langchain.schema.runnable.base import (
|
|||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
coerce_to_runnable,
|
coerce_to_runnable,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import RunnableConfig, get_config_list
|
||||||
from langchain.schema.runnable.utils import gather_with_concurrency
|
from langchain.schema.runnable.utils import gather_with_concurrency
|
||||||
|
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ class RouterRunnable(
|
|||||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||||
|
|
||||||
runnables = [self.runnables[key] for key in keys]
|
runnables = [self.runnables[key] for key in keys]
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||||
return list(
|
return list(
|
||||||
executor.map(
|
executor.map(
|
||||||
@ -156,7 +156,7 @@ class RouterRunnable(
|
|||||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||||
|
|
||||||
runnables = [self.runnables[key] for key in keys]
|
runnables = [self.runnables[key] for key in keys]
|
||||||
configs = self._get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
return await gather_with_concurrency(
|
return await gather_with_concurrency(
|
||||||
max_concurrency,
|
max_concurrency,
|
||||||
*(
|
*(
|
||||||
|
@ -2081,7 +2081,8 @@
|
|||||||
"stop": [
|
"stop": [
|
||||||
"Thought:"
|
"Thought:"
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
"config": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"llm": {
|
"llm": {
|
||||||
|
@ -11,6 +11,7 @@ from langchain import PromptTemplate
|
|||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
from langchain.chat_models.fake import FakeListChatModel
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
|
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
|
||||||
from langchain.load.dump import dumpd, dumps
|
from langchain.load.dump import dumpd, dumps
|
||||||
@ -112,6 +113,124 @@ class FakeRetriever(BaseRetriever):
|
|||||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_with_config(mocker: MockerFixture) -> None:
|
||||||
|
fake = FakeRunnable()
|
||||||
|
spy = mocker.spy(fake, "invoke")
|
||||||
|
|
||||||
|
assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"])),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
fake_1: Runnable = RunnablePassthrough()
|
||||||
|
fake_2: Runnable = RunnablePassthrough()
|
||||||
|
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
|
||||||
|
|
||||||
|
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
|
||||||
|
tags=["b-tag"], max_concurrency=5
|
||||||
|
)
|
||||||
|
assert sequence.invoke("hello") == "hello"
|
||||||
|
assert len(spy_seq_step.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy_seq_step.call_args_list):
|
||||||
|
assert call.args[1] == "hello"
|
||||||
|
if i == 0:
|
||||||
|
assert call.args[2].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[2].get("max_concurrency") is None
|
||||||
|
else:
|
||||||
|
assert call.args[2].get("tags") == ["b-tag"]
|
||||||
|
assert call.args[2].get("max_concurrency") == 5
|
||||||
|
spy_seq_step.reset_mock()
|
||||||
|
|
||||||
|
assert [
|
||||||
|
*fake.with_config(tags=["a-tag"]).stream(
|
||||||
|
"hello", dict(metadata={"key": "value"})
|
||||||
|
)
|
||||||
|
] == [5]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert fake.with_config(recursion_limit=5).batch(
|
||||||
|
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
||||||
|
) == [5, 7]
|
||||||
|
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy.call_args_list):
|
||||||
|
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||||
|
if i == 0:
|
||||||
|
assert call.args[1].get("recursion_limit") == 5
|
||||||
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[1].get("metadata") == {}
|
||||||
|
else:
|
||||||
|
assert call.args[1].get("recursion_limit") == 5
|
||||||
|
assert call.args[1].get("tags") == []
|
||||||
|
assert call.args[1].get("metadata") == {"key": "value"}
|
||||||
|
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert fake.with_config(metadata={"a": "b"}).batch(
|
||||||
|
["hello", "wooorld"], dict(tags=["a-tag"])
|
||||||
|
) == [5, 7]
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy.call_args_list):
|
||||||
|
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||||
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[1].get("metadata") == {"a": "b"}
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
handler = ConsoleCallbackHandler()
|
||||||
|
assert (
|
||||||
|
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
||||||
|
"hello", config={"callbacks": [handler]}
|
||||||
|
)
|
||||||
|
== 5
|
||||||
|
)
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert [
|
||||||
|
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
|
||||||
|
] == [5]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(metadata={"a": "b"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
|
||||||
|
["hello", "wooorld"], dict(metadata={"key": "value"})
|
||||||
|
) == [
|
||||||
|
5,
|
||||||
|
7,
|
||||||
|
]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call(
|
||||||
|
"hello",
|
||||||
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=["c"],
|
||||||
|
callbacks=None,
|
||||||
|
_locals={},
|
||||||
|
recursion_limit=5,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld",
|
||||||
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=["c"],
|
||||||
|
callbacks=None,
|
||||||
|
_locals={},
|
||||||
|
recursion_limit=5,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||||
fake = FakeRunnable()
|
fake = FakeRunnable()
|
||||||
@ -1125,6 +1244,14 @@ async def test_map_astream_iterator_input() -> None:
|
|||||||
assert final_value.get("passthrough") == llm_res
|
assert final_value.get("passthrough") == llm_res
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_config_with_config() -> None:
|
||||||
|
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||||
|
|
||||||
|
assert dumpd(
|
||||||
|
llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
|
||||||
|
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
|
||||||
|
|
||||||
|
|
||||||
def test_bind_bind() -> None:
|
def test_bind_bind() -> None:
|
||||||
llm = FakeListLLM(responses=["i'm a textbot"])
|
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user