Move patch_config

This commit is contained in:
Nuno Campos 2023-08-18 10:28:39 +01:00
parent 46f3850794
commit 1baedc4e18
2 changed files with 11 additions and 14 deletions

View File

@ -34,7 +34,6 @@ if TYPE_CHECKING:
) )
from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
@ -43,6 +42,7 @@ from langchain.schema.runnable.config import (
ensure_config, ensure_config,
get_async_callback_manager_for_config, get_async_callback_manager_for_config,
get_callback_manager_for_config, get_callback_manager_for_config,
patch_config,
) )
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
accepts_run_manager, accepts_run_manager,
@ -1472,18 +1472,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item yield item
def patch_config(
config: RunnableConfig,
callback_manager: BaseCallbackManager,
_locals: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callback_manager
if _locals is not None:
config["_locals"] = _locals
return config
def coerce_to_runnable( def coerce_to_runnable(
thing: Union[ thing: Union[
Runnable[Input, Output], Runnable[Input, Output],

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, TypedDict from typing import Any, Dict, List, Optional, TypedDict
from langchain.callbacks.base import Callbacks from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@ -40,6 +40,15 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
return empty return empty
def patch_config(
config: RunnableConfig,
callbacks: BaseCallbackManager,
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callbacks
return config
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
return CallbackManager.configure( return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),