From 75ae585deb4143e2c5f9774c74f10aa963fc0b3d Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Thu, 15 Aug 2024 09:56:31 -0700 Subject: [PATCH] Merge support for group manager (#25360) --- libs/core/langchain_core/callbacks/base.py | 72 ++++++++- libs/core/langchain_core/callbacks/manager.py | 152 ++++++++++++++++-- libs/core/langchain_core/runnables/config.py | 32 +--- .../tests/unit_tests/runnables/test_config.py | 53 +++++- 4 files changed, 258 insertions(+), 51 deletions(-) diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py index 7d3c7c0e784..bd817ef097b 100644 --- a/libs/core/langchain_core/callbacks/base.py +++ b/libs/core/langchain_core/callbacks/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union from uuid import UUID @@ -13,6 +14,8 @@ if TYPE_CHECKING: from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult +_LOGGER = logging.getLogger(__name__) + class RetrieverManagerMixin: """Mixin for Retriever callbacks.""" @@ -911,15 +914,72 @@ class BaseCallbackManager(CallbackManagerMixin): def copy(self: T) -> T: """Copy the callback manager.""" return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, + handlers=self.handlers.copy(), + inheritable_handlers=self.inheritable_handlers.copy(), parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, + tags=self.tags.copy(), + inheritable_tags=self.inheritable_tags.copy(), + metadata=self.metadata.copy(), + inheritable_metadata=self.inheritable_metadata.copy(), ) + def merge(self: T, other: BaseCallbackManager) -> T: + """Merge the callback manager with another callback manager. + + May be overwritten in subclasses. Primarily used internally + within merge_configs. + + Returns: + BaseCallbackManager: The merged callback manager of the same type + as the current object. + + Example: Merging two callback managers. + + .. code-block:: python + + from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group + from langchain_core.callbacks.stdout import StdOutCallbackHandler + + manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"]) + with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager: + merged_manager = group_manager.merge(manager) + print(merged_manager.handlers) + # [ + # , + # , + # ] + + print(merged_manager.tags) + # ['tag2', 'tag1'] + + """ # noqa: E501 + if self.parent_run_id != other.parent_run_id: + _LOGGER.warning( + f"{self.__class__.__name__}.merge(): Parent run IDs do not match." + " Using the parent run ID of the first callback manager." + ) + manager = self.__class__( + parent_run_id=self.parent_run_id or other.parent_run_id, + handlers=[], + inheritable_handlers=[], + tags=list(set(self.tags + other.tags)), + inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)), + metadata={ + **self.metadata, + **other.metadata, + }, + ) + + handlers = self.handlers + other.handlers + inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers + + for handler in handlers: + manager.add_handler(handler) + + for handler in inheritable_handlers: + manager.add_handler(handler, inherit=True) + return manager + @property def is_async(self) -> bool: """Whether the callback manager is async.""" diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 82163ee4067..77b963388c3 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -1612,16 +1612,80 @@ class CallbackManagerForChainGroup(CallbackManager): def copy(self) -> CallbackManagerForChainGroup: """Copy the callback manager.""" return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, + handlers=self.handlers.copy(), + inheritable_handlers=self.inheritable_handlers.copy(), parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, + tags=self.tags.copy(), + inheritable_tags=self.inheritable_tags.copy(), + metadata=self.metadata.copy(), + inheritable_metadata=self.inheritable_metadata.copy(), parent_run_manager=self.parent_run_manager, ) + def merge( + self: CallbackManagerForChainGroup, other: BaseCallbackManager + ) -> CallbackManagerForChainGroup: + """Merge the group callback manager with another callback manager. + + Overwrites the merge method in the base class to ensure that the + parent run manager is preserved. Keeps the parent_run_manager + from the current object. + + Returns: + CallbackManagerForChainGroup: A copy of the current object with the + handlers, tags, and other attributes merged from the other object. + + Example: Merging two callback managers. + + .. code-block:: python + + from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group + from langchain_core.callbacks.stdout import StdOutCallbackHandler + + manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"]) + with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager: + merged_manager = group_manager.merge(manager) + print(type(merged_manager)) + # + + print(merged_manager.handlers) + # [ + # , + # , + # ] + + print(merged_manager.tags) + # ['tag2', 'tag1'] + + """ # noqa: E501 + if self.parent_run_id != other.parent_run_id: + logger.warning( + f"{self.__class__.__name__}.merge(): Parent run IDs do not match." + " Using the parent run ID of the first callback manager." + ) + manager = self.__class__( + parent_run_id=self.parent_run_id or other.parent_run_id, + handlers=[], + inheritable_handlers=[], + tags=list(set(self.tags + other.tags)), + inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)), + metadata={ + **self.metadata, + **other.metadata, + }, + parent_run_manager=self.parent_run_manager, + ) + + handlers = self.handlers + other.handlers + inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers + + for handler in handlers: + manager.add_handler(handler) + + for handler in inheritable_handlers: + manager.add_handler(handler, inherit=True) + return manager + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: """Run when traced chain group ends. @@ -2040,16 +2104,80 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): def copy(self) -> AsyncCallbackManagerForChainGroup: """Copy the async callback manager.""" return self.__class__( - handlers=self.handlers, - inheritable_handlers=self.inheritable_handlers, + handlers=self.handlers.copy(), + inheritable_handlers=self.inheritable_handlers.copy(), parent_run_id=self.parent_run_id, - tags=self.tags, - inheritable_tags=self.inheritable_tags, - metadata=self.metadata, - inheritable_metadata=self.inheritable_metadata, + tags=self.tags.copy(), + inheritable_tags=self.inheritable_tags.copy(), + metadata=self.metadata.copy(), + inheritable_metadata=self.inheritable_metadata.copy(), parent_run_manager=self.parent_run_manager, ) + def merge( + self: AsyncCallbackManagerForChainGroup, other: BaseCallbackManager + ) -> AsyncCallbackManagerForChainGroup: + """Merge the group callback manager with another callback manager. + + Overwrites the merge method in the base class to ensure that the + parent run manager is preserved. Keeps the parent_run_manager + from the current object. + + Returns: + AsyncCallbackManagerForChainGroup: A copy of the current AsyncCallbackManagerForChainGroup + with the handlers, tags, etc. of the other callback manager merged in. + + Example: Merging two callback managers. + + .. code-block:: python + + from langchain_core.callbacks.manager import CallbackManager, atrace_as_chain_group + from langchain_core.callbacks.stdout import StdOutCallbackHandler + + manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"]) + async with atrace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager: + merged_manager = group_manager.merge(manager) + print(type(merged_manager)) + # + + print(merged_manager.handlers) + # [ + # , + # , + # ] + + print(merged_manager.tags) + # ['tag2', 'tag1'] + + """ # noqa: E501 + if self.parent_run_id != other.parent_run_id: + logger.warning( + f"{self.__class__.__name__}.merge(): Parent run IDs do not match." + " Using the parent run ID of the first callback manager." + ) + manager = self.__class__( + parent_run_id=self.parent_run_id or other.parent_run_id, + handlers=[], + inheritable_handlers=[], + tags=list(set(self.tags + other.tags)), + inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)), + metadata={ + **self.metadata, + **other.metadata, + }, + parent_run_manager=self.parent_run_manager, + ) + + handlers = self.handlers + other.handlers + inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers + + for handler in handlers: + manager.add_handler(handler) + + for handler in inheritable_handlers: + manager.add_handler(handler, inherit=True) + return manager + async def on_chain_end( self, outputs: Union[Dict[str, Any], Any], **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index da6aa98a4b6..1d9772bea11 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -348,37 +348,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: base["callbacks"] = mngr else: # base_callbacks is also a manager - manager = base_callbacks.__class__( - parent_run_id=base_callbacks.parent_run_id - or these_callbacks.parent_run_id, - handlers=[], - inheritable_handlers=[], - tags=list(set(base_callbacks.tags + these_callbacks.tags)), - inheritable_tags=list( - set( - base_callbacks.inheritable_tags - + these_callbacks.inheritable_tags - ) - ), - metadata={ - **base_callbacks.metadata, - **these_callbacks.metadata, - }, - ) - - handlers = base_callbacks.handlers + these_callbacks.handlers - inheritable_handlers = ( - base_callbacks.inheritable_handlers - + these_callbacks.inheritable_handlers - ) - - for handler in handlers: - manager.add_handler(handler) - - for handler in inheritable_handlers: - manager.add_handler(handler, inherit=True) - - base["callbacks"] = manager + base["callbacks"] = base_callbacks.merge(these_callbacks) elif key == "recursion_limit": if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT: base["recursion_limit"] = config["recursion_limit"] diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index a8118bbd3bb..23eac0cfd73 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -5,7 +5,12 @@ from typing import Any, Dict, cast import pytest -from langchain_core.callbacks.manager import CallbackManager +from langchain_core.callbacks.manager import ( + AsyncCallbackManager, + CallbackManager, + atrace_as_chain_group, + trace_as_chain_group, +) from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain_core.runnables import RunnableBinding, RunnablePassthrough @@ -66,7 +71,7 @@ def test_ensure_config() -> None: } -def test_merge_config_callbacks() -> None: +async def test_merge_config_callbacks() -> None: manager: RunnableConfig = { "callbacks": CallbackManager(handlers=[StdOutCallbackHandler()]) } @@ -94,6 +99,50 @@ def test_merge_config_callbacks() -> None: assert isinstance(merged[0], ConsoleCallbackHandler) assert isinstance(merged[1], StreamingStdOutCallbackHandler) + # Check that the original object wasn't mutated + merged = merge_configs(manager, handlers)["callbacks"] + + assert isinstance(merged, CallbackManager) + assert len(merged.handlers) == 2 + assert isinstance(merged.handlers[0], StdOutCallbackHandler) + assert isinstance(merged.handlers[1], ConsoleCallbackHandler) + + with trace_as_chain_group("test") as gm: + group_manager: RunnableConfig = { + "callbacks": gm, + } + merged = merge_configs(group_manager, handlers)["callbacks"] + assert isinstance(merged, CallbackManager) + assert len(merged.handlers) == 1 + assert isinstance(merged.handlers[0], ConsoleCallbackHandler) + + merged = merge_configs(handlers, group_manager)["callbacks"] + assert isinstance(merged, CallbackManager) + assert len(merged.handlers) == 1 + assert isinstance(merged.handlers[0], ConsoleCallbackHandler) + merged = merge_configs(group_manager, manager)["callbacks"] + assert isinstance(merged, CallbackManager) + assert len(merged.handlers) == 1 + assert isinstance(merged.handlers[0], StdOutCallbackHandler) + + async with atrace_as_chain_group("test_async") as gm: + group_manager = { + "callbacks": gm, + } + merged = merge_configs(group_manager, handlers)["callbacks"] + assert isinstance(merged, AsyncCallbackManager) + assert len(merged.handlers) == 1 + assert isinstance(merged.handlers[0], ConsoleCallbackHandler) + + merged = merge_configs(handlers, group_manager)["callbacks"] + assert isinstance(merged, AsyncCallbackManager) + assert len(merged.handlers) == 1 + assert isinstance(merged.handlers[0], ConsoleCallbackHandler) + merged = merge_configs(group_manager, manager)["callbacks"] + assert isinstance(merged, AsyncCallbackManager) + assert len(merged.handlers) == 1 + assert isinstance(merged.handlers[0], StdOutCallbackHandler) + def test_config_arbitrary_keys() -> None: base: RunnablePassthrough[Any] = RunnablePassthrough()