diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 0d9df2baeea..5a1d5b29e4f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -28,40 +28,18 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field -from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.config import ( + RunnableConfig, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, +) from langchain.schema.runnable.utils import ( gather_with_concurrency, ) from langchain.utils.aiter import atee, py_anext -def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: - empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) - if config is not None: - empty.update(config) - return empty - - -def _get_callback_manager(config: Mapping) -> Any: - from langchain.callbacks.manager import CallbackManager - - return CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - -def _get_async_callback_manager(config: Mapping) -> Any: - from langchain.callbacks.manager import AsyncCallbackManager - - return AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do Output = TypeVar("Output") @@ -241,7 +219,7 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [deepcopy(_ensure_config(config)) for _ in range(length)] + else [deepcopy(ensure_config(config)) for _ in range(length)] ) def _call_with_config( @@ -253,8 +231,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -283,8 +261,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -322,8 +300,8 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -387,8 +365,8 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -462,8 +440,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -495,8 +473,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -724,8 +702,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -753,8 +731,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -899,8 +877,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_callback_manager(config) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -966,8 +944,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -1068,7 +1046,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = _ensure_config(config) + config = ensure_config(config) callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1108,8 +1086,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: # setup callbacks - config = _ensure_config(config) - callback_manager = _get_async_callback_manager(config) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": input} diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index f2bf28fcb57..cd620077e1f 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List, Optional, TypedDict from langchain.callbacks.base import Callbacks +from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager class RunnableConfig(TypedDict, total=False): @@ -30,3 +31,28 @@ class RunnableConfig(TypedDict, total=False): """ Local variables """ + + +def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: + empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + if config is not None: + empty.update(config) + return empty + + +def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: + return CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +def get_async_callback_manager_for_config( + config: RunnableConfig, +) -> AsyncCallbackManager: + return AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + )