From 7eb33690a9a26a2c0babe01f9606850016ab3f58 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 3 Jan 2023 12:26:52 -0800 Subject: [PATCH] callback updates --- langchain/agents/agent.py | 2 +- langchain/callbacks/base.py | 13 +++++++++++-- langchain/callbacks/shared.py | 9 +++++++-- langchain/callbacks/stdout.py | 8 +++++++- langchain/callbacks/streamlit.py | 7 ++++++- tests/unit_tests/callbacks/fake_callback_handler.py | 7 ++++++- 6 files changed, 38 insertions(+), 8 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 2b67ae0d567..65e9172cd93 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -259,7 +259,7 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: - self.callback_manager.on_text(output.log, color="green") + self.callback_manager.on_agent_end(output, color="green") final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index f8ab7390848..c289fad43cf 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from langchain.schema import AgentAction, LLMResult +from langchain.schema import AgentAction, AgentFinish, LLMResult class BaseCallbackHandler(ABC): @@ -56,7 +56,11 @@ class BaseCallbackHandler(ABC): @abstractmethod def on_text(self, text: str, **kwargs: Any) -> None: - """Run when agent ends.""" + """Run on arbitrary text.""" + + @abstractmethod + def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" class BaseCallbackManager(BaseCallbackHandler, ABC): @@ -141,6 +145,11 @@ class CallbackManager(BaseCallbackManager): for handler in self.handlers: handler.on_text(text, **kwargs) + def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + for handler in self.handlers: + handler.on_agent_end(finish, **kwargs) + def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" self.handlers.append(handler) diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index 763b1de15bb..be1539156b1 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -8,7 +8,7 @@ from langchain.callbacks.base import ( BaseCallbackManager, CallbackManager, ) -from langchain.schema import AgentAction, LLMResult +from langchain.schema import AgentAction, AgentFinish, LLMResult class Singleton: @@ -89,10 +89,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): self._callback_manager.on_tool_error(error) def on_text(self, text: str, **kwargs: Any) -> None: - """Run when agent ends.""" + """Run on arbitrary text.""" with self._lock: self._callback_manager.on_text(text, **kwargs) + def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + with self._lock: + self._callback_manager.on_agent_end(finish, **kwargs) + def add_handler(self, callback: BaseCallbackHandler) -> None: """Add a callback to the callback manager.""" with self._lock: diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index d651a6f53b4..d6ad40d800a 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text -from langchain.schema import AgentAction, LLMResult +from langchain.schema import AgentAction, AgentFinish, LLMResult class StdOutCallbackHandler(BaseCallbackHandler): @@ -76,3 +76,9 @@ class StdOutCallbackHandler(BaseCallbackHandler): ) -> None: """Run when agent ends.""" print_text(text, color=color, end=end) + + def on_agent_end( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run on agent end.""" + print_text(finish.log, color=color, end="\n") diff --git a/langchain/callbacks/streamlit.py b/langchain/callbacks/streamlit.py index 9079b45cace..b56d8aa0668 100644 --- a/langchain/callbacks/streamlit.py +++ b/langchain/callbacks/streamlit.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional import streamlit as st from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, LLMResult +from langchain.schema import AgentAction, AgentFinish, LLMResult class StreamlitCallbackHandler(BaseCallbackHandler): @@ -70,3 +70,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler): """Run on text.""" # st.write requires two spaces before a newline to render it st.write(text.replace("\n", " \n")) + + def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run on agent end.""" + # st.write requires two spaces before a newline to render it + st.write(finish.log.replace("\n", " \n")) diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index fc62aa3212b..03bd07c6c19 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, LLMResult +from langchain.schema import AgentAction, LLMResult, AgentFinish class FakeCallbackHandler(BaseCallbackHandler): @@ -62,3 +62,8 @@ class FakeCallbackHandler(BaseCallbackHandler): def on_text(self, text: str, **kwargs: Any) -> None: """Run when agent is ending.""" self.ends += 1 + + def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.ends += 1 +