From 6cdd4b5edca511b0015f1b39102225fe638d8359 Mon Sep 17 00:00:00 2001 From: Alec Flett Date: Wed, 12 Jul 2023 00:48:29 -0700 Subject: [PATCH] only add handlers if they are new (#7504) When using callbacks, there are times when callbacks can be added redundantly: for instance sometimes you might need to create an llm with specific callbacks, but then also create and agent that uses a chain that has those callbacks already set. This means that "callbacks" might get passed down again to the llm at predict() time, resulting in duplicate calls to the `on_llm_start` callback. For the sake of simplicity, I made it so that langchain never adds an exact handler/callbacks object in `add_handler`, thus avoiding the duplicate handler issue. Tagging @hwchase17 for callback review --------- Co-authored-by: Bagatur --- langchain/callbacks/base.py | 5 ++-- .../callbacks/fake_callback_handler.py | 5 +++- .../callbacks/test_callback_manager.py | 23 ++++++++++++------- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 930d03fbbe9..2e15185f898 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -498,8 +498,9 @@ class BaseCallbackManager(CallbackManagerMixin): def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Add a handler to the callback manager.""" - self.handlers.append(handler) - if inherit: + if handler not in self.handlers: + self.handlers.append(handler) + if inherit and handler not in self.inheritable_handlers: self.inheritable_handlers.append(handler) def remove_handler(self, handler: BaseCallbackHandler) -> None: diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index dc7a8e777bf..a5e3d3ef692 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -1,6 +1,6 @@ """A fake callback handler for testing purposes.""" from itertools import chain -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -22,6 +22,9 @@ class BaseFakeCallbackHandler(BaseModel): ignore_retriever_: bool = False ignore_chat_model_: bool = False + # to allow for similar callback handlers that are not technicall equal + fake_id: Union[str, None] = None + # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 chain_ends: int = 0 diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 8d82fa90a3f..428d41f8c5b 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -178,10 +178,10 @@ async def test_async_callback_manager_sync_handler() -> None: def test_callback_manager_inheritance() -> None: handler1, handler2, handler3, handler4 = ( - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), + FakeCallbackHandler(fake_id="handler1"), + FakeCallbackHandler(fake_id="handler2"), + FakeCallbackHandler(fake_id="handler3"), + FakeCallbackHandler(fake_id="handler4"), ) callback_manager1 = CallbackManager(handlers=[handler1, handler2]) @@ -222,15 +222,22 @@ def test_callback_manager_inheritance() -> None: assert child_manager2.inheritable_handlers == [handler1] +def test_duplicate_callbacks() -> None: + handler = FakeCallbackHandler() + manager = CallbackManager(handlers=[handler]) + manager.add_handler(handler) + assert manager.handlers == [handler] + + def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None: """Test callback manager configuration.""" monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false") monkeypatch.setenv("LANGCHAIN_TRACING", "false") handler1, handler2, handler3, handler4 = ( - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), - FakeCallbackHandler(), + FakeCallbackHandler(fake_id="handler1"), + FakeCallbackHandler(fake_id="handler2"), + FakeCallbackHandler(fake_id="handler3"), + FakeCallbackHandler(fake_id="handler4"), ) inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]