Move config helpers

This commit is contained in:
Nuno Campos 2023-08-18 10:10:35 +01:00
parent a5e7dcec61
commit 8ddaaf3d41
2 changed files with 57 additions and 53 deletions

View File

@ -28,40 +28,18 @@ from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
from langchain.schema.runnable.utils import (
gather_with_concurrency,
)
from langchain.utils.aiter import atee, py_anext
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def _get_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def _get_async_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
@ -241,7 +219,7 @@ class Runnable(Generic[Input, Output], ABC):
return (
config
if isinstance(config, list)
else [deepcopy(_ensure_config(config)) for _ in range(length)]
else [deepcopy(ensure_config(config)) for _ in range(length)]
)
def _call_with_config(
@ -253,8 +231,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
@ -283,8 +261,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
@ -322,8 +300,8 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
@ -387,8 +365,8 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
{"input": ""},
@ -462,8 +440,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -495,8 +473,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -724,8 +702,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -753,8 +731,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -899,8 +877,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -966,8 +944,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -1068,7 +1046,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = _ensure_config(config)
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
@ -1108,8 +1086,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
# setup callbacks
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input}

View File

@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Any, Dict, List, TypedDict
from typing import Any, Dict, List, Optional, TypedDict
from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
class RunnableConfig(TypedDict, total=False):
@ -30,3 +31,28 @@ class RunnableConfig(TypedDict, total=False):
"""
Local variables
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def get_async_callback_manager_for_config(
config: RunnableConfig,
) -> AsyncCallbackManager:
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)