Merge support for group manager (#25360)

This commit is contained in:
William FH
2024-08-15 09:56:31 -07:00
committed by GitHub
parent b7c070d437
commit 75ae585deb
4 changed files with 258 additions and 51 deletions

View File

@@ -5,7 +5,12 @@ from typing import Any, Dict, cast
import pytest
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
CallbackManager,
atrace_as_chain_group,
trace_as_chain_group,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
@@ -66,7 +71,7 @@ def test_ensure_config() -> None:
}
def test_merge_config_callbacks() -> None:
async def test_merge_config_callbacks() -> None:
manager: RunnableConfig = {
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
}
@@ -94,6 +99,50 @@ def test_merge_config_callbacks() -> None:
assert isinstance(merged[0], ConsoleCallbackHandler)
assert isinstance(merged[1], StreamingStdOutCallbackHandler)
# Check that the original object wasn't mutated
merged = merge_configs(manager, handlers)["callbacks"]
assert isinstance(merged, CallbackManager)
assert len(merged.handlers) == 2
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
assert isinstance(merged.handlers[1], ConsoleCallbackHandler)
with trace_as_chain_group("test") as gm:
group_manager: RunnableConfig = {
"callbacks": gm,
}
merged = merge_configs(group_manager, handlers)["callbacks"]
assert isinstance(merged, CallbackManager)
assert len(merged.handlers) == 1
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
merged = merge_configs(handlers, group_manager)["callbacks"]
assert isinstance(merged, CallbackManager)
assert len(merged.handlers) == 1
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
merged = merge_configs(group_manager, manager)["callbacks"]
assert isinstance(merged, CallbackManager)
assert len(merged.handlers) == 1
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
async with atrace_as_chain_group("test_async") as gm:
group_manager = {
"callbacks": gm,
}
merged = merge_configs(group_manager, handlers)["callbacks"]
assert isinstance(merged, AsyncCallbackManager)
assert len(merged.handlers) == 1
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
merged = merge_configs(handlers, group_manager)["callbacks"]
assert isinstance(merged, AsyncCallbackManager)
assert len(merged.handlers) == 1
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
merged = merge_configs(group_manager, manager)["callbacks"]
assert isinstance(merged, AsyncCallbackManager)
assert len(merged.handlers) == 1
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
def test_config_arbitrary_keys() -> None:
base: RunnablePassthrough[Any] = RunnablePassthrough()