mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 10:13:29 +00:00
Merge support for group manager (#25360)
This commit is contained in:
parent
b7c070d437
commit
75ae585deb
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@ -13,6 +14,8 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RetrieverManagerMixin:
|
class RetrieverManagerMixin:
|
||||||
"""Mixin for Retriever callbacks."""
|
"""Mixin for Retriever callbacks."""
|
||||||
@ -911,15 +914,72 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
def copy(self: T) -> T:
|
def copy(self: T) -> T:
|
||||||
"""Copy the callback manager."""
|
"""Copy the callback manager."""
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
handlers=self.handlers,
|
handlers=self.handlers.copy(),
|
||||||
inheritable_handlers=self.inheritable_handlers,
|
inheritable_handlers=self.inheritable_handlers.copy(),
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
tags=self.tags,
|
tags=self.tags.copy(),
|
||||||
inheritable_tags=self.inheritable_tags,
|
inheritable_tags=self.inheritable_tags.copy(),
|
||||||
metadata=self.metadata,
|
metadata=self.metadata.copy(),
|
||||||
inheritable_metadata=self.inheritable_metadata,
|
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def merge(self: T, other: BaseCallbackManager) -> T:
|
||||||
|
"""Merge the callback manager with another callback manager.
|
||||||
|
|
||||||
|
May be overwritten in subclasses. Primarily used internally
|
||||||
|
within merge_configs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseCallbackManager: The merged callback manager of the same type
|
||||||
|
as the current object.
|
||||||
|
|
||||||
|
Example: Merging two callback managers.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
|
||||||
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
|
||||||
|
manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
|
||||||
|
with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
|
||||||
|
merged_manager = group_manager.merge(manager)
|
||||||
|
print(merged_manager.handlers)
|
||||||
|
# [
|
||||||
|
# <langchain_core.callbacks.stdout.StdOutCallbackHandler object at ...>,
|
||||||
|
# <langchain_core.callbacks.streaming_stdout.StreamingStdOutCallbackHandler object at ...>,
|
||||||
|
# ]
|
||||||
|
|
||||||
|
print(merged_manager.tags)
|
||||||
|
# ['tag2', 'tag1']
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
if self.parent_run_id != other.parent_run_id:
|
||||||
|
_LOGGER.warning(
|
||||||
|
f"{self.__class__.__name__}.merge(): Parent run IDs do not match."
|
||||||
|
" Using the parent run ID of the first callback manager."
|
||||||
|
)
|
||||||
|
manager = self.__class__(
|
||||||
|
parent_run_id=self.parent_run_id or other.parent_run_id,
|
||||||
|
handlers=[],
|
||||||
|
inheritable_handlers=[],
|
||||||
|
tags=list(set(self.tags + other.tags)),
|
||||||
|
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
|
||||||
|
metadata={
|
||||||
|
**self.metadata,
|
||||||
|
**other.metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
handlers = self.handlers + other.handlers
|
||||||
|
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
manager.add_handler(handler)
|
||||||
|
|
||||||
|
for handler in inheritable_handlers:
|
||||||
|
manager.add_handler(handler, inherit=True)
|
||||||
|
return manager
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_async(self) -> bool:
|
def is_async(self) -> bool:
|
||||||
"""Whether the callback manager is async."""
|
"""Whether the callback manager is async."""
|
||||||
|
@ -1612,16 +1612,80 @@ class CallbackManagerForChainGroup(CallbackManager):
|
|||||||
def copy(self) -> CallbackManagerForChainGroup:
|
def copy(self) -> CallbackManagerForChainGroup:
|
||||||
"""Copy the callback manager."""
|
"""Copy the callback manager."""
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
handlers=self.handlers,
|
handlers=self.handlers.copy(),
|
||||||
inheritable_handlers=self.inheritable_handlers,
|
inheritable_handlers=self.inheritable_handlers.copy(),
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
tags=self.tags,
|
tags=self.tags.copy(),
|
||||||
inheritable_tags=self.inheritable_tags,
|
inheritable_tags=self.inheritable_tags.copy(),
|
||||||
metadata=self.metadata,
|
metadata=self.metadata.copy(),
|
||||||
inheritable_metadata=self.inheritable_metadata,
|
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||||
parent_run_manager=self.parent_run_manager,
|
parent_run_manager=self.parent_run_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def merge(
|
||||||
|
self: CallbackManagerForChainGroup, other: BaseCallbackManager
|
||||||
|
) -> CallbackManagerForChainGroup:
|
||||||
|
"""Merge the group callback manager with another callback manager.
|
||||||
|
|
||||||
|
Overwrites the merge method in the base class to ensure that the
|
||||||
|
parent run manager is preserved. Keeps the parent_run_manager
|
||||||
|
from the current object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallbackManagerForChainGroup: A copy of the current object with the
|
||||||
|
handlers, tags, and other attributes merged from the other object.
|
||||||
|
|
||||||
|
Example: Merging two callback managers.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.callbacks.manager import CallbackManager, trace_as_chain_group
|
||||||
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
|
||||||
|
manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
|
||||||
|
with trace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
|
||||||
|
merged_manager = group_manager.merge(manager)
|
||||||
|
print(type(merged_manager))
|
||||||
|
# <class 'langchain_core.callbacks.manager.CallbackManagerForChainGroup'>
|
||||||
|
|
||||||
|
print(merged_manager.handlers)
|
||||||
|
# [
|
||||||
|
# <langchain_core.callbacks.stdout.LangChainTracer object at ...>,
|
||||||
|
# <langchain_core.callbacks.streaming_stdout.StdOutCallbackHandler object at ...>,
|
||||||
|
# ]
|
||||||
|
|
||||||
|
print(merged_manager.tags)
|
||||||
|
# ['tag2', 'tag1']
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
if self.parent_run_id != other.parent_run_id:
|
||||||
|
logger.warning(
|
||||||
|
f"{self.__class__.__name__}.merge(): Parent run IDs do not match."
|
||||||
|
" Using the parent run ID of the first callback manager."
|
||||||
|
)
|
||||||
|
manager = self.__class__(
|
||||||
|
parent_run_id=self.parent_run_id or other.parent_run_id,
|
||||||
|
handlers=[],
|
||||||
|
inheritable_handlers=[],
|
||||||
|
tags=list(set(self.tags + other.tags)),
|
||||||
|
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
|
||||||
|
metadata={
|
||||||
|
**self.metadata,
|
||||||
|
**other.metadata,
|
||||||
|
},
|
||||||
|
parent_run_manager=self.parent_run_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
handlers = self.handlers + other.handlers
|
||||||
|
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
manager.add_handler(handler)
|
||||||
|
|
||||||
|
for handler in inheritable_handlers:
|
||||||
|
manager.add_handler(handler, inherit=True)
|
||||||
|
return manager
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
|
||||||
"""Run when traced chain group ends.
|
"""Run when traced chain group ends.
|
||||||
|
|
||||||
@ -2040,16 +2104,80 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
|||||||
def copy(self) -> AsyncCallbackManagerForChainGroup:
|
def copy(self) -> AsyncCallbackManagerForChainGroup:
|
||||||
"""Copy the async callback manager."""
|
"""Copy the async callback manager."""
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
handlers=self.handlers,
|
handlers=self.handlers.copy(),
|
||||||
inheritable_handlers=self.inheritable_handlers,
|
inheritable_handlers=self.inheritable_handlers.copy(),
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
tags=self.tags,
|
tags=self.tags.copy(),
|
||||||
inheritable_tags=self.inheritable_tags,
|
inheritable_tags=self.inheritable_tags.copy(),
|
||||||
metadata=self.metadata,
|
metadata=self.metadata.copy(),
|
||||||
inheritable_metadata=self.inheritable_metadata,
|
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||||
parent_run_manager=self.parent_run_manager,
|
parent_run_manager=self.parent_run_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def merge(
|
||||||
|
self: AsyncCallbackManagerForChainGroup, other: BaseCallbackManager
|
||||||
|
) -> AsyncCallbackManagerForChainGroup:
|
||||||
|
"""Merge the group callback manager with another callback manager.
|
||||||
|
|
||||||
|
Overwrites the merge method in the base class to ensure that the
|
||||||
|
parent run manager is preserved. Keeps the parent_run_manager
|
||||||
|
from the current object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncCallbackManagerForChainGroup: A copy of the current AsyncCallbackManagerForChainGroup
|
||||||
|
with the handlers, tags, etc. of the other callback manager merged in.
|
||||||
|
|
||||||
|
Example: Merging two callback managers.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.callbacks.manager import CallbackManager, atrace_as_chain_group
|
||||||
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
|
||||||
|
manager = CallbackManager(handlers=[StdOutCallbackHandler()], tags=["tag2"])
|
||||||
|
async with atrace_as_chain_group("My Group Name", tags=["tag1"]) as group_manager:
|
||||||
|
merged_manager = group_manager.merge(manager)
|
||||||
|
print(type(merged_manager))
|
||||||
|
# <class 'langchain_core.callbacks.manager.AsyncCallbackManagerForChainGroup'>
|
||||||
|
|
||||||
|
print(merged_manager.handlers)
|
||||||
|
# [
|
||||||
|
# <langchain_core.callbacks.stdout.LangChainTracer object at ...>,
|
||||||
|
# <langchain_core.callbacks.streaming_stdout.StdOutCallbackHandler object at ...>,
|
||||||
|
# ]
|
||||||
|
|
||||||
|
print(merged_manager.tags)
|
||||||
|
# ['tag2', 'tag1']
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
if self.parent_run_id != other.parent_run_id:
|
||||||
|
logger.warning(
|
||||||
|
f"{self.__class__.__name__}.merge(): Parent run IDs do not match."
|
||||||
|
" Using the parent run ID of the first callback manager."
|
||||||
|
)
|
||||||
|
manager = self.__class__(
|
||||||
|
parent_run_id=self.parent_run_id or other.parent_run_id,
|
||||||
|
handlers=[],
|
||||||
|
inheritable_handlers=[],
|
||||||
|
tags=list(set(self.tags + other.tags)),
|
||||||
|
inheritable_tags=list(set(self.inheritable_tags + other.inheritable_tags)),
|
||||||
|
metadata={
|
||||||
|
**self.metadata,
|
||||||
|
**other.metadata,
|
||||||
|
},
|
||||||
|
parent_run_manager=self.parent_run_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
handlers = self.handlers + other.handlers
|
||||||
|
inheritable_handlers = self.inheritable_handlers + other.inheritable_handlers
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
manager.add_handler(handler)
|
||||||
|
|
||||||
|
for handler in inheritable_handlers:
|
||||||
|
manager.add_handler(handler, inherit=True)
|
||||||
|
return manager
|
||||||
|
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -348,37 +348,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
base["callbacks"] = mngr
|
base["callbacks"] = mngr
|
||||||
else:
|
else:
|
||||||
# base_callbacks is also a manager
|
# base_callbacks is also a manager
|
||||||
manager = base_callbacks.__class__(
|
base["callbacks"] = base_callbacks.merge(these_callbacks)
|
||||||
parent_run_id=base_callbacks.parent_run_id
|
|
||||||
or these_callbacks.parent_run_id,
|
|
||||||
handlers=[],
|
|
||||||
inheritable_handlers=[],
|
|
||||||
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
|
|
||||||
inheritable_tags=list(
|
|
||||||
set(
|
|
||||||
base_callbacks.inheritable_tags
|
|
||||||
+ these_callbacks.inheritable_tags
|
|
||||||
)
|
|
||||||
),
|
|
||||||
metadata={
|
|
||||||
**base_callbacks.metadata,
|
|
||||||
**these_callbacks.metadata,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
handlers = base_callbacks.handlers + these_callbacks.handlers
|
|
||||||
inheritable_handlers = (
|
|
||||||
base_callbacks.inheritable_handlers
|
|
||||||
+ these_callbacks.inheritable_handlers
|
|
||||||
)
|
|
||||||
|
|
||||||
for handler in handlers:
|
|
||||||
manager.add_handler(handler)
|
|
||||||
|
|
||||||
for handler in inheritable_handlers:
|
|
||||||
manager.add_handler(handler, inherit=True)
|
|
||||||
|
|
||||||
base["callbacks"] = manager
|
|
||||||
elif key == "recursion_limit":
|
elif key == "recursion_limit":
|
||||||
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
||||||
base["recursion_limit"] = config["recursion_limit"]
|
base["recursion_limit"] = config["recursion_limit"]
|
||||||
|
@ -5,7 +5,12 @@ from typing import Any, Dict, cast
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.callbacks.manager import CallbackManager
|
from langchain_core.callbacks.manager import (
|
||||||
|
AsyncCallbackManager,
|
||||||
|
CallbackManager,
|
||||||
|
atrace_as_chain_group,
|
||||||
|
trace_as_chain_group,
|
||||||
|
)
|
||||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
||||||
@ -66,7 +71,7 @@ def test_ensure_config() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_merge_config_callbacks() -> None:
|
async def test_merge_config_callbacks() -> None:
|
||||||
manager: RunnableConfig = {
|
manager: RunnableConfig = {
|
||||||
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
||||||
}
|
}
|
||||||
@ -94,6 +99,50 @@ def test_merge_config_callbacks() -> None:
|
|||||||
assert isinstance(merged[0], ConsoleCallbackHandler)
|
assert isinstance(merged[0], ConsoleCallbackHandler)
|
||||||
assert isinstance(merged[1], StreamingStdOutCallbackHandler)
|
assert isinstance(merged[1], StreamingStdOutCallbackHandler)
|
||||||
|
|
||||||
|
# Check that the original object wasn't mutated
|
||||||
|
merged = merge_configs(manager, handlers)["callbacks"]
|
||||||
|
|
||||||
|
assert isinstance(merged, CallbackManager)
|
||||||
|
assert len(merged.handlers) == 2
|
||||||
|
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
|
||||||
|
assert isinstance(merged.handlers[1], ConsoleCallbackHandler)
|
||||||
|
|
||||||
|
with trace_as_chain_group("test") as gm:
|
||||||
|
group_manager: RunnableConfig = {
|
||||||
|
"callbacks": gm,
|
||||||
|
}
|
||||||
|
merged = merge_configs(group_manager, handlers)["callbacks"]
|
||||||
|
assert isinstance(merged, CallbackManager)
|
||||||
|
assert len(merged.handlers) == 1
|
||||||
|
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
|
||||||
|
|
||||||
|
merged = merge_configs(handlers, group_manager)["callbacks"]
|
||||||
|
assert isinstance(merged, CallbackManager)
|
||||||
|
assert len(merged.handlers) == 1
|
||||||
|
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
|
||||||
|
merged = merge_configs(group_manager, manager)["callbacks"]
|
||||||
|
assert isinstance(merged, CallbackManager)
|
||||||
|
assert len(merged.handlers) == 1
|
||||||
|
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
|
||||||
|
|
||||||
|
async with atrace_as_chain_group("test_async") as gm:
|
||||||
|
group_manager = {
|
||||||
|
"callbacks": gm,
|
||||||
|
}
|
||||||
|
merged = merge_configs(group_manager, handlers)["callbacks"]
|
||||||
|
assert isinstance(merged, AsyncCallbackManager)
|
||||||
|
assert len(merged.handlers) == 1
|
||||||
|
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
|
||||||
|
|
||||||
|
merged = merge_configs(handlers, group_manager)["callbacks"]
|
||||||
|
assert isinstance(merged, AsyncCallbackManager)
|
||||||
|
assert len(merged.handlers) == 1
|
||||||
|
assert isinstance(merged.handlers[0], ConsoleCallbackHandler)
|
||||||
|
merged = merge_configs(group_manager, manager)["callbacks"]
|
||||||
|
assert isinstance(merged, AsyncCallbackManager)
|
||||||
|
assert len(merged.handlers) == 1
|
||||||
|
assert isinstance(merged.handlers[0], StdOutCallbackHandler)
|
||||||
|
|
||||||
|
|
||||||
def test_config_arbitrary_keys() -> None:
|
def test_config_arbitrary_keys() -> None:
|
||||||
base: RunnablePassthrough[Any] = RunnablePassthrough()
|
base: RunnablePassthrough[Any] = RunnablePassthrough()
|
||||||
|
Loading…
Reference in New Issue
Block a user