Move config helpers

This commit is contained in:
Nuno Campos 2023-08-18 10:10:35 +01:00
parent a5e7dcec61
commit 8ddaaf3d41
2 changed files with 57 additions and 53 deletions

View File

@ -28,40 +28,18 @@ from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
from langchain.schema.runnable.utils import ( from langchain.schema.runnable.utils import (
gather_with_concurrency, gather_with_concurrency,
) )
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def _get_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import CallbackManager
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def _get_async_callback_manager(config: Mapping) -> Any:
from langchain.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
Input = TypeVar("Input") Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do # Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output") Output = TypeVar("Output")
@ -241,7 +219,7 @@ class Runnable(Generic[Input, Output], ABC):
return ( return (
config config
if isinstance(config, list) if isinstance(config, list)
else [deepcopy(_ensure_config(config)) for _ in range(length)] else [deepcopy(ensure_config(config)) for _ in range(length)]
) )
def _call_with_config( def _call_with_config(
@ -253,8 +231,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
input if isinstance(input, dict) else {"input": input}, input if isinstance(input, dict) else {"input": input},
@ -283,8 +261,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses."""
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
input if isinstance(input, dict) else {"input": input}, input if isinstance(input, dict) else {"input": input},
@ -322,8 +300,8 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None final_output: Optional[Output] = None
final_output_supported = True final_output_supported = True
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
@ -387,8 +365,8 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None final_output: Optional[Output] = None
final_output_supported = True final_output_supported = True
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
{"input": ""}, {"input": ""},
@ -462,8 +440,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -495,8 +473,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -724,8 +702,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -753,8 +731,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -899,8 +877,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]: ) -> Iterator[Output]:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = get_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -966,8 +944,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self), input if isinstance(input, dict) else {"input": input}
@ -1068,7 +1046,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
local_callbacks=None, local_callbacks=None,
@ -1108,8 +1086,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# setup callbacks # setup callbacks
config = _ensure_config(config) config = ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = get_async_callback_manager_for_config(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input} dumpd(self), {"input": input}

View File

@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, TypedDict from typing import Any, Dict, List, Optional, TypedDict
from langchain.callbacks.base import Callbacks from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
class RunnableConfig(TypedDict, total=False): class RunnableConfig(TypedDict, total=False):
@ -30,3 +31,28 @@ class RunnableConfig(TypedDict, total=False):
""" """
Local variables Local variables
""" """
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
return CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
def get_async_callback_manager_for_config(
config: RunnableConfig,
) -> AsyncCallbackManager:
return AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)