From 3efee27e562bfba4ad1e6b2d9846340c26895731 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 3 Jan 2023 14:12:43 -0800 Subject: [PATCH] cr --- langchain/callbacks/base.py | 42 ++++++++++++------- langchain/callbacks/shared.py | 2 +- tests/unit_tests/agents/test_agent.py | 6 +-- .../callbacks/fake_callback_handler.py | 8 ++-- .../callbacks/test_callback_manager.py | 2 +- tests/unit_tests/chains/test_base.py | 6 ++- tests/unit_tests/llms/test_callbacks.py | 4 +- 7 files changed, 42 insertions(+), 28 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index c289fad43cf..fc2f73255fe 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -3,12 +3,18 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List +from pydantic import BaseModel + from langchain.schema import AgentAction, AgentFinish, LLMResult -class BaseCallbackHandler(ABC): +class BaseCallbackHandler(BaseModel, ABC): """Base callback handler that can be used to handle callbacks from langchain.""" + ignore_llm: bool = False + ignore_chain: bool = False + ignore_agent: bool = False + @abstractmethod def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any @@ -82,16 +88,15 @@ class BaseCallbackManager(BaseCallbackHandler, ABC): class CallbackManager(BaseCallbackManager): """Callback manager that can be used to handle callbacks from langchain.""" - def __init__(self, handlers: List[BaseCallbackHandler]) -> None: - """Initialize the callback manager.""" - self.handlers = handlers + handlers: List[BaseCallbackHandler] def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" for handler in self.handlers: - handler.on_llm_start(serialized, prompts, **kwargs) + if not handler.ignore_llm: + handler.on_llm_start(serialized, prompts, **kwargs) def on_llm_end( self, @@ -99,46 +104,54 @@ class CallbackManager(BaseCallbackManager): ) -> None: """Run when LLM ends running.""" for handler in self.handlers: - handler.on_llm_end(response) + if not handler.ignore_llm: + handler.on_llm_end(response) def on_llm_error(self, error: Exception) -> None: """Run when LLM errors.""" for handler in self.handlers: - handler.on_llm_error(error) + if not handler.ignore_llm: + handler.on_llm_error(error) def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" for handler in self.handlers: - handler.on_chain_start(serialized, inputs, **kwargs) + if not handler.ignore_chain: + handler.on_chain_start(serialized, inputs, **kwargs) def on_chain_end(self, outputs: Dict[str, Any]) -> None: """Run when chain ends running.""" for handler in self.handlers: - handler.on_chain_end(outputs) + if not handler.ignore_chain: + handler.on_chain_end(outputs) def on_chain_error(self, error: Exception) -> None: """Run when chain errors.""" for handler in self.handlers: - handler.on_chain_error(error) + if not handler.ignore_chain: + handler.on_chain_error(error) def on_tool_start( self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any ) -> None: """Run when tool starts running.""" for handler in self.handlers: - handler.on_tool_start(serialized, action, **kwargs) + if not handler.ignore_agent: + handler.on_tool_start(serialized, action, **kwargs) def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" for handler in self.handlers: - handler.on_tool_end(output, **kwargs) + if not handler.ignore_agent: + handler.on_tool_end(output, **kwargs) def on_tool_error(self, error: Exception) -> None: """Run when tool errors.""" for handler in self.handlers: - handler.on_tool_error(error) + if not handler.ignore_agent: + handler.on_tool_error(error) def on_text(self, text: str, **kwargs: Any) -> None: """Run on additional input from chains and agents.""" @@ -148,7 +161,8 @@ class CallbackManager(BaseCallbackManager): def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: """Run on agent end.""" for handler in self.handlers: - handler.on_agent_end(finish, **kwargs) + if not handler.ignore_agent: + handler.on_agent_end(finish, **kwargs) def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index be1539156b1..93d49ced568 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -32,7 +32,7 @@ class Singleton: class SharedCallbackManager(Singleton, BaseCallbackManager): """A thread-safe singleton CallbackManager.""" - _callback_manager: CallbackManager = CallbackManager([]) + _callback_manager: CallbackManager = CallbackManager(handlers=[]) def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 5a25d62599a..4816fa90d93 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -71,7 +71,7 @@ def test_agent_with_callbacks_global() -> None: langchain.verbose = True handler = FakeCallbackHandler() - manager = CallbackManager([handler]) + manager = CallbackManager(handlers=[handler]) tool = "Search" responses = [ f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", @@ -105,7 +105,7 @@ def test_agent_with_callbacks_local() -> None: langchain.verbose = False handler = FakeCallbackHandler() - manager = CallbackManager([handler]) + manager = CallbackManager(handlers=[handler]) tool = "Search" responses = [ f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", @@ -141,7 +141,7 @@ def test_agent_with_callbacks_not_verbose() -> None: langchain.verbose = False handler = FakeCallbackHandler() - manager = CallbackManager([handler]) + manager = CallbackManager(handlers=[handler]) tool = "Search" responses = [ f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 77da5a63c91..b757cb28719 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -8,11 +8,9 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult class FakeCallbackHandler(BaseCallbackHandler): """Fake callback handler for testing.""" - def __init__(self) -> None: - """Initialize the mock callback handler.""" - self.starts = 0 - self.ends = 0 - self.errors = 0 + starts: int = 0 + ends: int = 0 + errors: int = 0 def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 915bf5808ae..d0927e15aa0 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -29,7 +29,7 @@ def test_callback_manager() -> None: """Test the CallbackManager.""" handler1 = FakeCallbackHandler() handler2 = FakeCallbackHandler() - manager = CallbackManager([handler1, handler2]) + manager = CallbackManager(handlers=[handler1, handler2]) _test_callback_manager(manager, handler1, handler2) diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 8d7775f35ad..c07f857631e 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -140,7 +140,9 @@ def test_run_arg_with_memory() -> None: def test_run_with_callback() -> None: """Test run method works when callback manager is passed.""" handler = FakeCallbackHandler() - chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True) + chain = FakeChain( + callback_manager=CallbackManager(handlers=[handler]), verbose=True + ) output = chain.run("bar") assert output == "baz" assert handler.starts == 1 @@ -155,7 +157,7 @@ def test_run_with_callback_not_verbose() -> None: langchain.verbose = False handler = FakeCallbackHandler() - chain = FakeChain(callback_manager=CallbackManager([handler])) + chain = FakeChain(callback_manager=CallbackManager(handlers=[handler])) output = chain.run("bar") assert output == "baz" assert handler.starts == 0 diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index 997bafd3da2..d9d52630b7f 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -7,7 +7,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM def test_llm_with_callbacks() -> None: """Test LLM callbacks.""" handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True) + llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True) output = llm("foo") assert output == "foo" assert handler.starts == 1 @@ -22,7 +22,7 @@ def test_llm_with_callbacks_not_verbose() -> None: langchain.verbose = False handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager([handler])) + llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler])) output = llm("foo") assert output == "foo" assert handler.starts == 0