mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
Move patch_config
This commit is contained in:
parent
46f3850794
commit
1baedc4e18
@ -34,7 +34,6 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field
|
||||||
@ -43,6 +42,7 @@ from langchain.schema.runnable.config import (
|
|||||||
ensure_config,
|
ensure_config,
|
||||||
get_async_callback_manager_for_config,
|
get_async_callback_manager_for_config,
|
||||||
get_callback_manager_for_config,
|
get_callback_manager_for_config,
|
||||||
|
patch_config,
|
||||||
)
|
)
|
||||||
from langchain.schema.runnable.utils import (
|
from langchain.schema.runnable.utils import (
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
@ -1472,18 +1472,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
def patch_config(
|
|
||||||
config: RunnableConfig,
|
|
||||||
callback_manager: BaseCallbackManager,
|
|
||||||
_locals: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> RunnableConfig:
|
|
||||||
config = config.copy()
|
|
||||||
config["callbacks"] = callback_manager
|
|
||||||
if _locals is not None:
|
|
||||||
config["_locals"] = _locals
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def coerce_to_runnable(
|
def coerce_to_runnable(
|
||||||
thing: Union[
|
thing: Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional, TypedDict
|
from typing import Any, Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from langchain.callbacks.base import Callbacks
|
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||||
|
|
||||||
|
|
||||||
@ -40,6 +40,15 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
return empty
|
return empty
|
||||||
|
|
||||||
|
|
||||||
|
def patch_config(
|
||||||
|
config: RunnableConfig,
|
||||||
|
callbacks: BaseCallbackManager,
|
||||||
|
) -> RunnableConfig:
|
||||||
|
config = config.copy()
|
||||||
|
config["callbacks"] = callbacks
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||||
return CallbackManager.configure(
|
return CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
Loading…
Reference in New Issue
Block a user