mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
Runnable locals(#9007)
Adds Runnables that can manipulate variables local to a RunnableSequence run --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
8a03836160
commit
1c64db575c
@ -1,3 +1,4 @@
|
|||||||
|
from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableBinding,
|
RunnableBinding,
|
||||||
@ -11,6 +12,8 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough
|
|||||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"GetLocalVar",
|
||||||
|
"PutLocalVar",
|
||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
"Runnable",
|
"Runnable",
|
||||||
|
156
libs/langchain/langchain/schema/runnable/_locals.py
Normal file
156
libs/langchain/langchain/schema/runnable/_locals.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.runnable.base import Input, Output, Runnable
|
||||||
|
from langchain.schema.runnable.config import RunnableConfig
|
||||||
|
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||||
|
|
||||||
|
|
||||||
|
class PutLocalVar(RunnablePassthrough):
|
||||||
|
key: Union[str, Mapping[str, str]]
|
||||||
|
"""The key(s) to use for storing the input variable(s) in local state.
|
||||||
|
|
||||||
|
If a string is provided then the entire input is stored under that key. If a
|
||||||
|
Mapping is provided, then the map values are gotten from the input and
|
||||||
|
stored in local state under the map keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
|
||||||
|
super().__init__(key=key, **kwargs)
|
||||||
|
|
||||||
|
def _concat_put(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
*,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
replace: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||||
|
"therefore always receive a non-null config."
|
||||||
|
)
|
||||||
|
if isinstance(self.key, str):
|
||||||
|
if self.key not in config["_locals"] or replace:
|
||||||
|
config["_locals"][self.key] = input
|
||||||
|
else:
|
||||||
|
config["_locals"][self.key] += input
|
||||||
|
elif isinstance(self.key, Mapping):
|
||||||
|
if not isinstance(input, Mapping):
|
||||||
|
raise TypeError(
|
||||||
|
f"Received key of type Mapping but input of type {type(input)}. "
|
||||||
|
f"input is expected to be of type Mapping when key is Mapping."
|
||||||
|
)
|
||||||
|
for input_key, put_key in self.key.items():
|
||||||
|
if put_key not in config["_locals"] or replace:
|
||||||
|
config["_locals"][put_key] = input[input_key]
|
||||||
|
else:
|
||||||
|
config["_locals"][put_key] += input[input_key]
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"`key` should be a string or Mapping[str, str], received type "
|
||||||
|
f"{(type(self.key))}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||||
|
self._concat_put(input, config=config, replace=True)
|
||||||
|
return super().invoke(input, config=config)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Input:
|
||||||
|
self._concat_put(input, config=config, replace=True)
|
||||||
|
return await super().ainvoke(input, config=config)
|
||||||
|
|
||||||
|
def transform(
|
||||||
|
self,
|
||||||
|
input: Iterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Input]:
|
||||||
|
for chunk in super().transform(input, config=config):
|
||||||
|
self._concat_put(chunk, config=config)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def atransform(
|
||||||
|
self,
|
||||||
|
input: AsyncIterator[Input],
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Input]:
|
||||||
|
async for chunk in super().atransform(input, config=config):
|
||||||
|
self._concat_put(chunk, config=config)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class GetLocalVar(
|
||||||
|
Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]]
|
||||||
|
):
|
||||||
|
key: str
|
||||||
|
"""The key to extract from the local state."""
|
||||||
|
passthrough_key: Optional[str] = None
|
||||||
|
"""The key to use for passing through the invocation input.
|
||||||
|
|
||||||
|
If None, then only the value retrieved from local state is returned. Otherwise a
|
||||||
|
dictionary ``{self.key: <<retrieved_value>>, self.passthrough_key: <<input>>}``
|
||||||
|
is returned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||||
|
super().__init__(key=key, **kwargs)
|
||||||
|
|
||||||
|
def _get(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: Union[CallbackManagerForChainRun, Any],
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
|
if self.passthrough_key:
|
||||||
|
return {
|
||||||
|
self.key: config["_locals"][self.key],
|
||||||
|
self.passthrough_key: input,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return config["_locals"][self.key]
|
||||||
|
|
||||||
|
async def _aget(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
run_manager: AsyncCallbackManagerForChainRun,
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
|
return self._get(input, run_manager, config)
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
|
if config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"GetLocalVar should only be used in a RunnableSequence, and should "
|
||||||
|
"therefore always receive a non-null config."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._call_with_config(self._get, input, config)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: Input,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
|
if config is None:
|
||||||
|
raise ValueError(
|
||||||
|
"GetLocalVar should only be used in a RunnableSequence, and should "
|
||||||
|
"therefore always receive a non-null config."
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self._acall_with_config(self._aget, input, config)
|
@ -5,6 +5,7 @@ import copy
|
|||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||||
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import tee
|
from itertools import tee
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -34,11 +35,16 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_async_callback_manager_for_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
patch_config,
|
||||||
|
)
|
||||||
from langchain.schema.runnable.utils import (
|
from langchain.schema.runnable.utils import (
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
accepts_run_manager_and_config,
|
accepts_run_manager_and_config,
|
||||||
@ -238,9 +244,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
config
|
list(map(ensure_config, config))
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [config.copy() if config is not None else {} for _ in range(length)]
|
else [deepcopy(ensure_config(config)) for _ in range(length)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
@ -256,14 +262,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
from langchain.callbacks.manager import CallbackManager
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
config = config or {}
|
|
||||||
callback_manager = CallbackManager.configure(
|
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input,
|
input,
|
||||||
@ -303,14 +303,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
config = config or {}
|
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input,
|
input,
|
||||||
@ -358,8 +352,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""Helper method to transform an Iterator of Input values into an Iterator of
|
"""Helper method to transform an Iterator of Input values into an Iterator of
|
||||||
Output values, with callbacks.
|
Output values, with callbacks.
|
||||||
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
|
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = tee(input, 2)
|
input_for_tracing, input_for_transform = tee(input, 2)
|
||||||
# Start the input iterator to ensure the input runnable starts before this one
|
# Start the input iterator to ensure the input runnable starts before this one
|
||||||
@ -368,12 +360,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
@ -444,8 +432,6 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""Helper method to transform an Async Iterator of Input values into an Async
|
"""Helper method to transform an Async Iterator of Input values into an Async
|
||||||
Iterator of Output values, with callbacks.
|
Iterator of Output values, with callbacks.
|
||||||
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
|
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# tee the input so we can iterate over it twice
|
# tee the input so we can iterate over it twice
|
||||||
input_for_tracing, input_for_transform = atee(input, 2)
|
input_for_tracing, input_for_transform = atee(input, 2)
|
||||||
# Start the input iterator to ensure the input runnable starts before this one
|
# Start the input iterator to ensure the input runnable starts before this one
|
||||||
@ -454,12 +440,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
@ -535,19 +517,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
yield from self.fallbacks
|
yield from self.fallbacks
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
first_error = None
|
first_error = None
|
||||||
@ -577,19 +549,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
|
|
||||||
@ -808,19 +770,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
|
|
||||||
@ -846,19 +798,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
|
|
||||||
@ -993,19 +935,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = callback_manager.on_chain_start(dumpd(self), input)
|
||||||
|
|
||||||
@ -1069,19 +1001,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
|
|
||||||
@ -1193,7 +1115,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
local_callbacks=None,
|
local_callbacks=None,
|
||||||
@ -1216,7 +1138,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
step.invoke,
|
step.invoke,
|
||||||
input,
|
input,
|
||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
patch_config(config, run_manager.get_child()),
|
patch_config(deepcopy(config), run_manager.get_child()),
|
||||||
)
|
)
|
||||||
for step in steps.values()
|
for step in steps.values()
|
||||||
]
|
]
|
||||||
@ -1235,19 +1157,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or {}
|
config = ensure_config(config)
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
local_callbacks=None,
|
|
||||||
verbose=False,
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
local_tags=None,
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
local_metadata=None,
|
|
||||||
)
|
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
run_manager = await callback_manager.on_chain_start(dumpd(self), input)
|
||||||
|
|
||||||
@ -1540,14 +1452,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
def patch_config(
|
|
||||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
|
||||||
) -> RunnableConfig:
|
|
||||||
config = config.copy()
|
|
||||||
config["callbacks"] = callback_manager
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def coerce_to_runnable(
|
def coerce_to_runnable(
|
||||||
thing: Union[
|
thing: Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, TypedDict
|
from typing import Any, Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from langchain.callbacks.base import Callbacks
|
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||||
|
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
|
|
||||||
|
|
||||||
class RunnableConfig(TypedDict, total=False):
|
class RunnableConfig(TypedDict, total=False):
|
||||||
@ -25,3 +26,42 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_locals: Dict[str, Any]
|
||||||
|
"""
|
||||||
|
Local variables
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
|
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||||
|
if config is not None:
|
||||||
|
empty.update(config)
|
||||||
|
return empty
|
||||||
|
|
||||||
|
|
||||||
|
def patch_config(
|
||||||
|
config: RunnableConfig,
|
||||||
|
callbacks: BaseCallbackManager,
|
||||||
|
) -> RunnableConfig:
|
||||||
|
config = config.copy()
|
||||||
|
config["callbacks"] = callbacks
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||||
|
return CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_callback_manager_for_config(
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> AsyncCallbackManager:
|
||||||
|
return AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
)
|
||||||
|
@ -47,10 +47,11 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
) -> Iterator[Input]:
|
) -> Iterator[Input]:
|
||||||
return self._transform_stream_with_config(input, identity, config)
|
return self._transform_stream_with_config(input, identity, config)
|
||||||
|
|
||||||
def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
input: AsyncIterator[Input],
|
input: AsyncIterator[Input],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Input]:
|
) -> AsyncIterator[Input]:
|
||||||
return self._atransform_stream_with_config(input, identity, config)
|
async for chunk in self._atransform_stream_with_config(input, identity, config):
|
||||||
|
yield chunk
|
||||||
|
@ -1352,6 +1352,7 @@
|
|||||||
"lc": 1,
|
"lc": 1,
|
||||||
"type": "not_implemented",
|
"type": "not_implemented",
|
||||||
"id": [
|
"id": [
|
||||||
|
"runnable",
|
||||||
"test_runnable",
|
"test_runnable",
|
||||||
"FakeRetriever"
|
"FakeRetriever"
|
||||||
]
|
]
|
@ -0,0 +1,93 @@
|
|||||||
|
from typing import Any, Callable, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain import PromptTemplate
|
||||||
|
from langchain.llms import FakeListLLM
|
||||||
|
from langchain.schema.runnable import (
|
||||||
|
GetLocalVar,
|
||||||
|
PutLocalVar,
|
||||||
|
RunnablePassthrough,
|
||||||
|
RunnableSequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("method", "input", "output"),
|
||||||
|
[
|
||||||
|
(lambda r, x: r.invoke(x), "foo", "foo"),
|
||||||
|
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
|
||||||
|
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_put_get(method: Callable, input: Any, output: Any) -> None:
|
||||||
|
runnable = PutLocalVar("input") | GetLocalVar("input")
|
||||||
|
assert method(runnable, input) == output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("method", "input", "output"),
|
||||||
|
[
|
||||||
|
(lambda r, x: r.ainvoke(x), "foo", "foo"),
|
||||||
|
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
|
||||||
|
runnable = PutLocalVar("input") | GetLocalVar("input")
|
||||||
|
assert await method(runnable, input) == output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("runnable", "error"),
|
||||||
|
[
|
||||||
|
(PutLocalVar("input"), ValueError),
|
||||||
|
(GetLocalVar("input"), ValueError),
|
||||||
|
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
|
||||||
|
with pytest.raises(error):
|
||||||
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_in_map() -> None:
|
||||||
|
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
||||||
|
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_in_map() -> None:
|
||||||
|
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"runnable",
|
||||||
|
[
|
||||||
|
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
|
||||||
|
(
|
||||||
|
PutLocalVar("input")
|
||||||
|
| {"input": RunnablePassthrough()}
|
||||||
|
| PromptTemplate.from_template("say {input}")
|
||||||
|
| FakeListLLM(responses=["hello"])
|
||||||
|
| GetLocalVar("input", passthrough_key="output")
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("method", "input", "output"),
|
||||||
|
[
|
||||||
|
(lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}),
|
||||||
|
(lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]),
|
||||||
|
(
|
||||||
|
lambda r, x: list(r.stream(x))[0],
|
||||||
|
"hello",
|
||||||
|
{"input": "hello", "output": "hello"},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_put_get_sequence(
|
||||||
|
runnable: RunnableSequence, method: Callable, input: Any, output: Any
|
||||||
|
) -> None:
|
||||||
|
assert method(runnable, input) == output
|
@ -132,15 +132,24 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
||||||
) == [5, 7]
|
) == [5, 7]
|
||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call("hello", dict(tags=["a-tag"])),
|
mocker.call(
|
||||||
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld",
|
||||||
|
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call("hello", dict(tags=["a-tag"])),
|
mocker.call(
|
||||||
mocker.call("wooorld", dict(tags=["a-tag"])),
|
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||||
|
),
|
||||||
]
|
]
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
@ -161,8 +170,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
7,
|
7,
|
||||||
]
|
]
|
||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call("hello", dict(metadata={"key": "value"})),
|
mocker.call(
|
||||||
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
"hello",
|
||||||
|
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld",
|
||||||
|
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user