mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 01:19:31 +00:00
only add handlers if they are new (#7504)
When using callbacks, there are times when callbacks can be added redundantly: for instance sometimes you might need to create an llm with specific callbacks, but then also create and agent that uses a chain that has those callbacks already set. This means that "callbacks" might get passed down again to the llm at predict() time, resulting in duplicate calls to the `on_llm_start` callback. For the sake of simplicity, I made it so that langchain never adds an exact handler/callbacks object in `add_handler`, thus avoiding the duplicate handler issue. Tagging @hwchase17 for callback review --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
50316f6477
commit
6cdd4b5edc
@ -498,8 +498,9 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
|
|
||||||
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||||
"""Add a handler to the callback manager."""
|
"""Add a handler to the callback manager."""
|
||||||
self.handlers.append(handler)
|
if handler not in self.handlers:
|
||||||
if inherit:
|
self.handlers.append(handler)
|
||||||
|
if inherit and handler not in self.inheritable_handlers:
|
||||||
self.inheritable_handlers.append(handler)
|
self.inheritable_handlers.append(handler)
|
||||||
|
|
||||||
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
def remove_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""A fake callback handler for testing purposes."""
|
"""A fake callback handler for testing purposes."""
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -22,6 +22,9 @@ class BaseFakeCallbackHandler(BaseModel):
|
|||||||
ignore_retriever_: bool = False
|
ignore_retriever_: bool = False
|
||||||
ignore_chat_model_: bool = False
|
ignore_chat_model_: bool = False
|
||||||
|
|
||||||
|
# to allow for similar callback handlers that are not technicall equal
|
||||||
|
fake_id: Union[str, None] = None
|
||||||
|
|
||||||
# add finer-grained counters for easier debugging of failing tests
|
# add finer-grained counters for easier debugging of failing tests
|
||||||
chain_starts: int = 0
|
chain_starts: int = 0
|
||||||
chain_ends: int = 0
|
chain_ends: int = 0
|
||||||
|
@ -178,10 +178,10 @@ async def test_async_callback_manager_sync_handler() -> None:
|
|||||||
|
|
||||||
def test_callback_manager_inheritance() -> None:
|
def test_callback_manager_inheritance() -> None:
|
||||||
handler1, handler2, handler3, handler4 = (
|
handler1, handler2, handler3, handler4 = (
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler1"),
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler2"),
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler3"),
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler4"),
|
||||||
)
|
)
|
||||||
|
|
||||||
callback_manager1 = CallbackManager(handlers=[handler1, handler2])
|
callback_manager1 = CallbackManager(handlers=[handler1, handler2])
|
||||||
@ -222,15 +222,22 @@ def test_callback_manager_inheritance() -> None:
|
|||||||
assert child_manager2.inheritable_handlers == [handler1]
|
assert child_manager2.inheritable_handlers == [handler1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_callbacks() -> None:
|
||||||
|
handler = FakeCallbackHandler()
|
||||||
|
manager = CallbackManager(handlers=[handler])
|
||||||
|
manager.add_handler(handler)
|
||||||
|
assert manager.handlers == [handler]
|
||||||
|
|
||||||
|
|
||||||
def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
"""Test callback manager configuration."""
|
"""Test callback manager configuration."""
|
||||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
|
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
|
||||||
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
||||||
handler1, handler2, handler3, handler4 = (
|
handler1, handler2, handler3, handler4 = (
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler1"),
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler2"),
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler3"),
|
||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(fake_id="handler4"),
|
||||||
)
|
)
|
||||||
|
|
||||||
inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]
|
inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]
|
||||||
|
Loading…
Reference in New Issue
Block a user