diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 1ca853174a2..5fec1c86ca7 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -34,7 +34,6 @@ if TYPE_CHECKING: ) -from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field @@ -43,6 +42,7 @@ from langchain.schema.runnable.config import ( ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, + patch_config, ) from langchain.schema.runnable.utils import ( accepts_run_manager, @@ -1472,18 +1472,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): yield item -def patch_config( - config: RunnableConfig, - callback_manager: BaseCallbackManager, - _locals: Optional[Dict[str, Any]] = None, -) -> RunnableConfig: - config = config.copy() - config["callbacks"] = callback_manager - if _locals is not None: - config["_locals"] = _locals - return config - - def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 716fc361161..00408b7ee6c 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, TypedDict -from langchain.callbacks.base import Callbacks +from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager @@ -40,6 +40,15 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: return empty +def patch_config( + config: RunnableConfig, + callbacks: BaseCallbackManager, +) -> RunnableConfig: + config = config.copy() + config["callbacks"] = callbacks + return config + + def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: return CallbackManager.configure( inheritable_callbacks=config.get("callbacks"),