This commit is contained in:
Harrison Chase 2023-01-03 14:12:43 -08:00
parent 7d0b1cafd7
commit 3efee27e56
7 changed files with 42 additions and 28 deletions

View File

@ -3,12 +3,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(ABC): class BaseCallbackHandler(BaseModel, ABC):
"""Base callback handler that can be used to handle callbacks from langchain.""" """Base callback handler that can be used to handle callbacks from langchain."""
ignore_llm: bool = False
ignore_chain: bool = False
ignore_agent: bool = False
@abstractmethod @abstractmethod
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@ -82,16 +88,15 @@ class BaseCallbackManager(BaseCallbackHandler, ABC):
class CallbackManager(BaseCallbackManager): class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain.""" """Callback manager that can be used to handle callbacks from langchain."""
def __init__(self, handlers: List[BaseCallbackHandler]) -> None: handlers: List[BaseCallbackHandler]
"""Initialize the callback manager."""
self.handlers = handlers
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running."""
for handler in self.handlers: for handler in self.handlers:
handler.on_llm_start(serialized, prompts, **kwargs) if not handler.ignore_llm:
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end( def on_llm_end(
self, self,
@ -99,46 +104,54 @@ class CallbackManager(BaseCallbackManager):
) -> None: ) -> None:
"""Run when LLM ends running.""" """Run when LLM ends running."""
for handler in self.handlers: for handler in self.handlers:
handler.on_llm_end(response) if not handler.ignore_llm:
handler.on_llm_end(response)
def on_llm_error(self, error: Exception) -> None: def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
for handler in self.handlers: for handler in self.handlers:
handler.on_llm_error(error) if not handler.ignore_llm:
handler.on_llm_error(error)
def on_chain_start( def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Run when chain starts running.""" """Run when chain starts running."""
for handler in self.handlers: for handler in self.handlers:
handler.on_chain_start(serialized, inputs, **kwargs) if not handler.ignore_chain:
handler.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any]) -> None: def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running.""" """Run when chain ends running."""
for handler in self.handlers: for handler in self.handlers:
handler.on_chain_end(outputs) if not handler.ignore_chain:
handler.on_chain_end(outputs)
def on_chain_error(self, error: Exception) -> None: def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors.""" """Run when chain errors."""
for handler in self.handlers: for handler in self.handlers:
handler.on_chain_error(error) if not handler.ignore_chain:
handler.on_chain_error(error)
def on_tool_start( def on_tool_start(
self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
for handler in self.handlers: for handler in self.handlers:
handler.on_tool_start(serialized, action, **kwargs) if not handler.ignore_agent:
handler.on_tool_start(serialized, action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None: def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.""" """Run when tool ends running."""
for handler in self.handlers: for handler in self.handlers:
handler.on_tool_end(output, **kwargs) if not handler.ignore_agent:
handler.on_tool_end(output, **kwargs)
def on_tool_error(self, error: Exception) -> None: def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors.""" """Run when tool errors."""
for handler in self.handlers: for handler in self.handlers:
handler.on_tool_error(error) if not handler.ignore_agent:
handler.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on additional input from chains and agents.""" """Run on additional input from chains and agents."""
@ -148,7 +161,8 @@ class CallbackManager(BaseCallbackManager):
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end.""" """Run on agent end."""
for handler in self.handlers: for handler in self.handlers:
handler.on_agent_end(finish, **kwargs) if not handler.ignore_agent:
handler.on_agent_end(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None: def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager.""" """Add a handler to the callback manager."""

View File

@ -32,7 +32,7 @@ class Singleton:
class SharedCallbackManager(Singleton, BaseCallbackManager): class SharedCallbackManager(Singleton, BaseCallbackManager):
"""A thread-safe singleton CallbackManager.""" """A thread-safe singleton CallbackManager."""
_callback_manager: CallbackManager = CallbackManager([]) _callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any

View File

@ -71,7 +71,7 @@ def test_agent_with_callbacks_global() -> None:
langchain.verbose = True langchain.verbose = True
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
manager = CallbackManager([handler]) manager = CallbackManager(handlers=[handler])
tool = "Search" tool = "Search"
responses = [ responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
@ -105,7 +105,7 @@ def test_agent_with_callbacks_local() -> None:
langchain.verbose = False langchain.verbose = False
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
manager = CallbackManager([handler]) manager = CallbackManager(handlers=[handler])
tool = "Search" tool = "Search"
responses = [ responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
@ -141,7 +141,7 @@ def test_agent_with_callbacks_not_verbose() -> None:
langchain.verbose = False langchain.verbose = False
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
manager = CallbackManager([handler]) manager = CallbackManager(handlers=[handler])
tool = "Search" tool = "Search"
responses = [ responses = [
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",

View File

@ -8,11 +8,9 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
class FakeCallbackHandler(BaseCallbackHandler): class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing.""" """Fake callback handler for testing."""
def __init__(self) -> None: starts: int = 0
"""Initialize the mock callback handler.""" ends: int = 0
self.starts = 0 errors: int = 0
self.ends = 0
self.errors = 0
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any

View File

@ -29,7 +29,7 @@ def test_callback_manager() -> None:
"""Test the CallbackManager.""" """Test the CallbackManager."""
handler1 = FakeCallbackHandler() handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler() handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2]) manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2) _test_callback_manager(manager, handler1, handler2)

View File

@ -140,7 +140,9 @@ def test_run_arg_with_memory() -> None:
def test_run_with_callback() -> None: def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed.""" """Test run method works when callback manager is passed."""
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True) chain = FakeChain(
callback_manager=CallbackManager(handlers=[handler]), verbose=True
)
output = chain.run("bar") output = chain.run("bar")
assert output == "baz" assert output == "baz"
assert handler.starts == 1 assert handler.starts == 1
@ -155,7 +157,7 @@ def test_run_with_callback_not_verbose() -> None:
langchain.verbose = False langchain.verbose = False
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager([handler])) chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
output = chain.run("bar") output = chain.run("bar")
assert output == "baz" assert output == "baz"
assert handler.starts == 0 assert handler.starts == 0

View File

@ -7,7 +7,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None: def test_llm_with_callbacks() -> None:
"""Test LLM callbacks.""" """Test LLM callbacks."""
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True) llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
output = llm("foo") output = llm("foo")
assert output == "foo" assert output == "foo"
assert handler.starts == 1 assert handler.starts == 1
@ -22,7 +22,7 @@ def test_llm_with_callbacks_not_verbose() -> None:
langchain.verbose = False langchain.verbose = False
handler = FakeCallbackHandler() handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager([handler])) llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
output = llm("foo") output = llm("foo")
assert output == "foo" assert output == "foo"
assert handler.starts == 0 assert handler.starts == 0