mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Move config helpers
This commit is contained in:
parent
a5e7dcec61
commit
8ddaaf3d41
@ -28,40 +28,18 @@ from langchain.callbacks.base import BaseCallbackManager
|
|||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field
|
||||||
from langchain.schema.runnable.config import RunnableConfig
|
from langchain.schema.runnable.config import (
|
||||||
|
RunnableConfig,
|
||||||
|
ensure_config,
|
||||||
|
get_async_callback_manager_for_config,
|
||||||
|
get_callback_manager_for_config,
|
||||||
|
)
|
||||||
from langchain.schema.runnable.utils import (
|
from langchain.schema.runnable.utils import (
|
||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
)
|
)
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
|
|
||||||
|
|
||||||
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
|
||||||
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
|
||||||
if config is not None:
|
|
||||||
empty.update(config)
|
|
||||||
return empty
|
|
||||||
|
|
||||||
|
|
||||||
def _get_callback_manager(config: Mapping) -> Any:
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
return CallbackManager.configure(
|
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_async_callback_manager(config: Mapping) -> Any:
|
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
return AsyncCallbackManager.configure(
|
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
|
||||||
inheritable_tags=config.get("tags"),
|
|
||||||
inheritable_metadata=config.get("metadata"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input")
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
Output = TypeVar("Output")
|
Output = TypeVar("Output")
|
||||||
@ -241,7 +219,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return (
|
return (
|
||||||
config
|
config
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [deepcopy(_ensure_config(config)) for _ in range(length)]
|
else [deepcopy(ensure_config(config)) for _ in range(length)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
@ -253,8 +231,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input if isinstance(input, dict) else {"input": input},
|
||||||
@ -283,8 +261,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
input if isinstance(input, dict) else {"input": input},
|
input if isinstance(input, dict) else {"input": input},
|
||||||
@ -322,8 +300,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
@ -387,8 +365,8 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
{"input": ""},
|
{"input": ""},
|
||||||
@ -462,8 +440,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -495,8 +473,8 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -724,8 +702,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -753,8 +731,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -899,8 +877,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -966,8 +944,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
||||||
@ -1068,7 +1046,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
local_callbacks=None,
|
local_callbacks=None,
|
||||||
@ -1108,8 +1086,8 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = _ensure_config(config)
|
config = ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = get_async_callback_manager_for_config(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self), {"input": input}
|
dumpd(self), {"input": input}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, TypedDict
|
from typing import Any, Dict, List, Optional, TypedDict
|
||||||
|
|
||||||
from langchain.callbacks.base import Callbacks
|
from langchain.callbacks.base import Callbacks
|
||||||
|
from langchain.callbacks.manager import CallbackManager, AsyncCallbackManager
|
||||||
|
|
||||||
|
|
||||||
class RunnableConfig(TypedDict, total=False):
|
class RunnableConfig(TypedDict, total=False):
|
||||||
@ -30,3 +31,28 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
Local variables
|
Local variables
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
|
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||||
|
if config is not None:
|
||||||
|
empty.update(config)
|
||||||
|
return empty
|
||||||
|
|
||||||
|
|
||||||
|
def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager:
|
||||||
|
return CallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_callback_manager_for_config(
|
||||||
|
config: RunnableConfig,
|
||||||
|
) -> AsyncCallbackManager:
|
||||||
|
return AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
|
inheritable_tags=config.get("tags"),
|
||||||
|
inheritable_metadata=config.get("metadata"),
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user