Move patch_config

This commit is contained in:
Nuno Campos 2023-08-18 10:28:39 +01:00
parent 46f3850794
commit 1baedc4e18
2 changed files with 11 additions and 14 deletions

View File

@ -34,7 +34,6 @@ if TYPE_CHECKING:
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
@ -43,6 +42,7 @@ from langchain.schema.runnable.config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
from langchain.schema.runnable.utils import (
accepts_run_manager,
@ -1472,18 +1472,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item
def patch_config(
config: RunnableConfig,
callback_manager: BaseCallbackManager,
_locals: Optional[Dict[str, Any]] = None,
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callback_manager
if _locals is not None:
config["_locals"] = _locals
return config
def coerce_to_runnable(
thing: Union[
Runnable[Input, Output],

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, TypedDict
from langchain.callbacks.base import Callbacks
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
@ -40,6 +40,15 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
return empty
def patch_config(
config: RunnableConfig,
callbacks: BaseCallbackManager,
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callbacks
return config
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),