diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 930d03fbbe9..2e15185f898 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -498,8 +498,9 @@ class BaseCallbackManager(CallbackManagerMixin): def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Add a handler to the callback manager.""" - self.handlers.append(handler) - if inherit: + if handler not in self.handlers: + self.handlers.append(handler) + if inherit and handler not in self.inheritable_handlers: self.inheritable_handlers.append(handler) def remove_handler(self, handler: BaseCallbackHandler) -> None: diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index dc7a8e777bf..a5e3d3ef692 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -1,6 +1,6 @@ """A fake callback handler for testing purposes.""" from itertools import chain -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -22,6 +22,9 @@ class BaseFakeCallbackHandler(BaseModel): ignore_retriever_: bool = False ignore_chat_model_: bool = False + # to allow for similar callback handlers that are not technicall equal + fake_id: Union[str, None] = None + # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 chain_ends: int = 0 diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 8d82fa90a3f..428d41f8c5b 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -178,10 +178,10 @@ async def test_async_callback_manager_sync_handler() -> None: def test_callback_manager_inheritance() -> None: handler1, handler2, handler3, handler4 = ( - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), + FakeCallbackHandler(fake_id="handler1"), + FakeCallbackHandler(fake_id="handler2"), + FakeCallbackHandler(fake_id="handler3"), + FakeCallbackHandler(fake_id="handler4"), ) callback_manager1 = CallbackManager(handlers=[handler1, handler2]) @@ -222,15 +222,22 @@ def test_callback_manager_inheritance() -> None: assert child_manager2.inheritable_handlers == [handler1] +def test_duplicate_callbacks() -> None: + handler = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler]) + manager.add_handler(handler) + assert manager.handlers == [handler] + + def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None: """Test callback manager configuration.""" monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false") monkeypatch.setenv("LANGCHAIN_TRACING", "false") handler1, handler2, handler3, handler4 = ( - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), + FakeCallbackHandler(fake_id="handler1"), + FakeCallbackHandler(fake_id="handler2"), + FakeCallbackHandler(fake_id="handler3"), + FakeCallbackHandler(fake_id="handler4"), ) inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]