mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-26 04:51:45 +00:00
cr
This commit is contained in:
parent
7d0b1cafd7
commit
3efee27e56
@ -3,12 +3,18 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class BaseCallbackHandler(ABC):
|
class BaseCallbackHandler(BaseModel, ABC):
|
||||||
"""Base callback handler that can be used to handle callbacks from langchain."""
|
"""Base callback handler that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
|
ignore_llm: bool = False
|
||||||
|
ignore_chain: bool = False
|
||||||
|
ignore_agent: bool = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
@ -82,15 +88,14 @@ class BaseCallbackManager(BaseCallbackHandler, ABC):
|
|||||||
class CallbackManager(BaseCallbackManager):
|
class CallbackManager(BaseCallbackManager):
|
||||||
"""Callback manager that can be used to handle callbacks from langchain."""
|
"""Callback manager that can be used to handle callbacks from langchain."""
|
||||||
|
|
||||||
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
|
handlers: List[BaseCallbackHandler]
|
||||||
"""Initialize the callback manager."""
|
|
||||||
self.handlers = handlers
|
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_llm:
|
||||||
handler.on_llm_start(serialized, prompts, **kwargs)
|
handler.on_llm_start(serialized, prompts, **kwargs)
|
||||||
|
|
||||||
def on_llm_end(
|
def on_llm_end(
|
||||||
@ -99,11 +104,13 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM ends running."""
|
"""Run when LLM ends running."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_llm:
|
||||||
handler.on_llm_end(response)
|
handler.on_llm_end(response)
|
||||||
|
|
||||||
def on_llm_error(self, error: Exception) -> None:
|
def on_llm_error(self, error: Exception) -> None:
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_llm:
|
||||||
handler.on_llm_error(error)
|
handler.on_llm_error(error)
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
@ -111,16 +118,19 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_chain:
|
||||||
handler.on_chain_start(serialized, inputs, **kwargs)
|
handler.on_chain_start(serialized, inputs, **kwargs)
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||||
"""Run when chain ends running."""
|
"""Run when chain ends running."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_chain:
|
||||||
handler.on_chain_end(outputs)
|
handler.on_chain_end(outputs)
|
||||||
|
|
||||||
def on_chain_error(self, error: Exception) -> None:
|
def on_chain_error(self, error: Exception) -> None:
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_chain:
|
||||||
handler.on_chain_error(error)
|
handler.on_chain_error(error)
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
@ -128,16 +138,19 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
handler.on_tool_start(serialized, action, **kwargs)
|
handler.on_tool_start(serialized, action, **kwargs)
|
||||||
|
|
||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
handler.on_tool_end(output, **kwargs)
|
handler.on_tool_end(output, **kwargs)
|
||||||
|
|
||||||
def on_tool_error(self, error: Exception) -> None:
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
handler.on_tool_error(error)
|
handler.on_tool_error(error)
|
||||||
|
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
@ -148,6 +161,7 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
"""Run on agent end."""
|
"""Run on agent end."""
|
||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
|
if not handler.ignore_agent:
|
||||||
handler.on_agent_end(finish, **kwargs)
|
handler.on_agent_end(finish, **kwargs)
|
||||||
|
|
||||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
|
@ -32,7 +32,7 @@ class Singleton:
|
|||||||
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||||
"""A thread-safe singleton CallbackManager."""
|
"""A thread-safe singleton CallbackManager."""
|
||||||
|
|
||||||
_callback_manager: CallbackManager = CallbackManager([])
|
_callback_manager: CallbackManager = CallbackManager(handlers=[])
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
@ -71,7 +71,7 @@ def test_agent_with_callbacks_global() -> None:
|
|||||||
|
|
||||||
langchain.verbose = True
|
langchain.verbose = True
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
manager = CallbackManager([handler])
|
manager = CallbackManager(handlers=[handler])
|
||||||
tool = "Search"
|
tool = "Search"
|
||||||
responses = [
|
responses = [
|
||||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||||
@ -105,7 +105,7 @@ def test_agent_with_callbacks_local() -> None:
|
|||||||
|
|
||||||
langchain.verbose = False
|
langchain.verbose = False
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
manager = CallbackManager([handler])
|
manager = CallbackManager(handlers=[handler])
|
||||||
tool = "Search"
|
tool = "Search"
|
||||||
responses = [
|
responses = [
|
||||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||||
@ -141,7 +141,7 @@ def test_agent_with_callbacks_not_verbose() -> None:
|
|||||||
|
|
||||||
langchain.verbose = False
|
langchain.verbose = False
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
manager = CallbackManager([handler])
|
manager = CallbackManager(handlers=[handler])
|
||||||
tool = "Search"
|
tool = "Search"
|
||||||
responses = [
|
responses = [
|
||||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||||
|
@ -8,11 +8,9 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult
|
|||||||
class FakeCallbackHandler(BaseCallbackHandler):
|
class FakeCallbackHandler(BaseCallbackHandler):
|
||||||
"""Fake callback handler for testing."""
|
"""Fake callback handler for testing."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
starts: int = 0
|
||||||
"""Initialize the mock callback handler."""
|
ends: int = 0
|
||||||
self.starts = 0
|
errors: int = 0
|
||||||
self.ends = 0
|
|
||||||
self.errors = 0
|
|
||||||
|
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
@ -29,7 +29,7 @@ def test_callback_manager() -> None:
|
|||||||
"""Test the CallbackManager."""
|
"""Test the CallbackManager."""
|
||||||
handler1 = FakeCallbackHandler()
|
handler1 = FakeCallbackHandler()
|
||||||
handler2 = FakeCallbackHandler()
|
handler2 = FakeCallbackHandler()
|
||||||
manager = CallbackManager([handler1, handler2])
|
manager = CallbackManager(handlers=[handler1, handler2])
|
||||||
_test_callback_manager(manager, handler1, handler2)
|
_test_callback_manager(manager, handler1, handler2)
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,7 +140,9 @@ def test_run_arg_with_memory() -> None:
|
|||||||
def test_run_with_callback() -> None:
|
def test_run_with_callback() -> None:
|
||||||
"""Test run method works when callback manager is passed."""
|
"""Test run method works when callback manager is passed."""
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True)
|
chain = FakeChain(
|
||||||
|
callback_manager=CallbackManager(handlers=[handler]), verbose=True
|
||||||
|
)
|
||||||
output = chain.run("bar")
|
output = chain.run("bar")
|
||||||
assert output == "baz"
|
assert output == "baz"
|
||||||
assert handler.starts == 1
|
assert handler.starts == 1
|
||||||
@ -155,7 +157,7 @@ def test_run_with_callback_not_verbose() -> None:
|
|||||||
langchain.verbose = False
|
langchain.verbose = False
|
||||||
|
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
chain = FakeChain(callback_manager=CallbackManager([handler]))
|
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
|
||||||
output = chain.run("bar")
|
output = chain.run("bar")
|
||||||
assert output == "baz"
|
assert output == "baz"
|
||||||
assert handler.starts == 0
|
assert handler.starts == 0
|
||||||
|
@ -7,7 +7,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
|
|||||||
def test_llm_with_callbacks() -> None:
|
def test_llm_with_callbacks() -> None:
|
||||||
"""Test LLM callbacks."""
|
"""Test LLM callbacks."""
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True)
|
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
|
||||||
output = llm("foo")
|
output = llm("foo")
|
||||||
assert output == "foo"
|
assert output == "foo"
|
||||||
assert handler.starts == 1
|
assert handler.starts == 1
|
||||||
@ -22,7 +22,7 @@ def test_llm_with_callbacks_not_verbose() -> None:
|
|||||||
langchain.verbose = False
|
langchain.verbose = False
|
||||||
|
|
||||||
handler = FakeCallbackHandler()
|
handler = FakeCallbackHandler()
|
||||||
llm = FakeLLM(callback_manager=CallbackManager([handler]))
|
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
|
||||||
output = llm("foo")
|
output = llm("foo")
|
||||||
assert output == "foo"
|
assert output == "foo"
|
||||||
assert handler.starts == 0
|
assert handler.starts == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user