This commit is contained in:
Bagatur 2023-08-09 13:26:09 -07:00
parent 539672a7fd
commit 50b13ab938
3 changed files with 60 additions and 114 deletions

View File

@ -62,7 +62,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return self(input, **(config or {}), **kwargs) _config: Dict[str, Any] = dict(config) if config else {}
_config.pop("_locals", None)
return self(input, **_config, **kwargs)
async def ainvoke( async def ainvoke(
self, self,
@ -76,7 +78,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
None, partial(self.invoke, input, config, **kwargs) None, partial(self.invoke, input, config, **kwargs)
) )
return await self.acall(input, **(config or {}), **kwargs) _config: Dict[str, Any] = dict(config) if config else {}
_config.pop("_locals", None)
return await self.acall(input, **_config, **kwargs)
memory: Optional[BaseMemory] = None memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None. """Optional memory object. Defaults to None.

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from itertools import tee from itertools import tee
from typing import ( from typing import (
Any, Any,
@ -66,6 +67,35 @@ class RunnableConfig(TypedDict, total=False):
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
""" """
_locals: Dict[str, Any]
"""
Local variables
"""
def _empty_config() -> RunnableConfig:
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
def _get_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def _get_async_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
Input = TypeVar("Input") Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do # Output type should implement __concat__, as eg str, list, dict do
@ -243,7 +273,7 @@ class Runnable(Generic[Input, Output], ABC):
return ( return (
config config
if isinstance(config, list) if isinstance(config, list)
else [config.copy() if config is not None else {} for _ in range(length)] else [deepcopy(config) if config is not None else {} for _ in range(length)]
) )
def _call_with_config( def _call_with_config(
@ -255,14 +285,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
from langchain.callbacks.manager import CallbackManager
config = config or {} config = config or {}
callback_manager = CallbackManager.configure( callback_manager = _get_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
input if isinstance(input, dict) else {"input": input}, input if isinstance(input, dict) else {"input": input},
@ -288,14 +312,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
config = config or {} config = config or {}
callback_manager = AsyncCallbackManager.configure( callback_manager = _get_async_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
input if isinstance(input, dict) else {"input": input}, input if isinstance(input, dict) else {"input": input},
@ -322,8 +340,6 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Iterator of Input values into an Iterator of """Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks. Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses.""" Use this to implement `stream()` or `transform()` in Runnable subclasses."""
from langchain.callbacks.manager import CallbackManager
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = tee(input, 2) input_for_tracing, input_for_transform = tee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one # Start the input iterator to ensure the input runnable starts before this one
@ -333,11 +349,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output_supported = True final_output_supported = True
config = config or {} config = config or {}
callback_manager = CallbackManager.configure( callback_manager = _get_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
@ -393,8 +405,6 @@ class Runnable(Generic[Input, Output], ABC):
"""Helper method to transform an Async Iterator of Input values into an Async """Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks. Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
from langchain.callbacks.manager import AsyncCallbackManager
# tee the input so we can iterate over it twice # tee the input so we can iterate over it twice
input_for_tracing, input_for_transform = atee(input, 2) input_for_tracing, input_for_transform = atee(input, 2)
# Start the input iterator to ensure the input runnable starts before this one # Start the input iterator to ensure the input runnable starts before this one
@ -404,11 +414,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output_supported = True final_output_supported = True
config = config or {} config = config or {}
callback_manager = AsyncCallbackManager.configure( callback_manager = _get_async_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
@ -473,19 +479,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
yield from self.fallbacks yield from self.fallbacks
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or {}
callback_manager = CallbackManager.configure( callback_manager = _get_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -516,19 +512,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or {}
callback_manager = AsyncCallbackManager.configure( callback_manager = _get_async_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -751,19 +737,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) )
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or _empty_config()
callback_manager = CallbackManager.configure( callback_manager = _get_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -771,11 +747,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke all steps in sequence # invoke all steps in sequence
try: try:
callbacks = run_manager.get_child()
for step in self.steps: for step in self.steps:
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
_patch_config(config, run_manager.get_child()), _patch_config(config, callbacks),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -790,19 +767,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or {}
callback_manager = AsyncCallbackManager.configure( callback_manager = _get_async_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -946,19 +913,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def stream( def stream(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]: ) -> Iterator[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or {}
callback_manager = CallbackManager.configure( callback_manager = _get_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -1023,19 +980,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
async def astream( async def astream(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or {}
callback_manager = AsyncCallbackManager.configure( callback_manager = _get_async_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -1173,19 +1120,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
async def ainvoke( async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks # setup callbacks
config = config or {} config = config or {}
callback_manager = AsyncCallbackManager.configure( callback_manager = _get_async_callback_manager(config)
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input} dumpd(self), {"input": input}
@ -1464,10 +1401,11 @@ class RouterRunnable(
def _patch_config( def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None
) -> RunnableConfig: ) -> RunnableConfig:
config = config.copy() config = deepcopy(config)
config["callbacks"] = callback_manager config["callbacks"] = callback_manager
config["_locals"] = _locals or {}
return config return config

View File

@ -636,7 +636,9 @@ async def _arun_chain(
else: else:
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
else: else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) runnable_config = RunnableConfig(
tags=tags or [], callbacks=callbacks, _locals={}
)
output = await chain.ainvoke(inputs_, config=runnable_config) output = await chain.ainvoke(inputs_, config=runnable_config)
return output return output
@ -957,7 +959,9 @@ def _run_chain(
else: else:
output = chain(inputs_, callbacks=callbacks, tags=tags) output = chain(inputs_, callbacks=callbacks, tags=tags)
else: else:
runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) runnable_config = RunnableConfig(
tags=tags or [], callbacks=callbacks, _locals={}
)
output = chain.invoke(inputs_, config=runnable_config) output = chain.invoke(inputs_, config=runnable_config)
return output return output