mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
Move config helpers
This commit is contained in:
parent
a5e7dcec61
commit
8ddaaf3d41
@ -28,40 +28,18 @@ from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
get_callback_manager_for_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import (
|
||||
gather_with_concurrency,
|
||||
)
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
|
||||
|
||||
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||
if config is not None:
|
||||
empty.update(config)
|
||||
return empty
|
||||
|
||||
|
||||
def _get_callback_manager(config: Mapping) -> Any:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
return CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def _get_async_callback_manager(config: Mapping) -> Any:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
return AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
Output = TypeVar("Output")
|
||||
@ -241,7 +219,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return (
|
||||
config
|
||||
if isinstance(config, list)
|
||||
else [deepcopy(_ensure_config(config)) for _ in range(length)]
|
||||
else [deepcopy(ensure_config(config)) for _ in range(length)]
|
||||
)
|
||||
|
||||
def _call_with_config(
|
||||
@ -253,8 +231,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
@ -283,8 +261,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input if isinstance(input, dict) else {"input": input},
|
||||
@ -322,8 +300,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
@ -387,8 +365,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
@ -462,8 +440,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -495,8 +473,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -724,8 +702,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -753,8 +731,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -899,8 +877,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -966,8 +944,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||
@ -1068,7 +1046,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
@ -1108,8 +1086,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
# setup callbacks
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), {"input": input}
|
||||
|
@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
from langchain.callbacks.base import Callbacks
|
||||
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
@ -30,3 +31,28 @@ class RunnableConfig(TypedDict, total=False):
|
||||
"""
|
||||
Local variables
|
||||
"""
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||
if config is not None:
|
||||
empty.update(config)
|
||||
return empty
|
||||
|
||||
|
||||
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||
return CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
||||
|
||||
def get_async_callback_manager_for_config(
|
||||
config: RunnableConfig,
|
||||
) -> AsyncCallbackManager:
|
||||
return AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
inheritable_tags=config.get("tags"),
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user