callback updates

This commit is contained in:
Harrison Chase 2023-01-03 12:26:52 -08:00
parent 23b8cfc123
commit 7eb33690a9
6 changed files with 38 additions and 8 deletions

View File

@ -259,7 +259,7 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
if self.verbose: if self.verbose:
self.callback_manager.on_text(output.log, color="green") self.callback_manager.on_agent_end(output, color="green")
final_output = output.return_values final_output = output.return_values
if self.return_intermediate_steps: if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps final_output["intermediate_steps"] = intermediate_steps

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.schema import AgentAction, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(ABC): class BaseCallbackHandler(ABC):
@ -56,7 +56,11 @@ class BaseCallbackHandler(ABC):
@abstractmethod @abstractmethod
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent ends.""" """Run on arbitrary text."""
@abstractmethod
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC): class BaseCallbackManager(BaseCallbackHandler, ABC):
@ -141,6 +145,11 @@ class CallbackManager(BaseCallbackManager):
for handler in self.handlers: for handler in self.handlers:
handler.on_text(text, **kwargs) handler.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
for handler in self.handlers:
handler.on_agent_end(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None: def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager.""" """Add a handler to the callback manager."""
self.handlers.append(handler) self.handlers.append(handler)

View File

@ -8,7 +8,7 @@ from langchain.callbacks.base import (
BaseCallbackManager, BaseCallbackManager,
CallbackManager, CallbackManager,
) )
from langchain.schema import AgentAction, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
class Singleton: class Singleton:
@ -89,10 +89,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
self._callback_manager.on_tool_error(error) self._callback_manager.on_tool_error(error)
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent ends.""" """Run on arbitrary text."""
with self._lock: with self._lock:
self._callback_manager.on_text(text, **kwargs) self._callback_manager.on_text(text, **kwargs)
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
with self._lock:
self._callback_manager.on_agent_end(finish, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None: def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager.""" """Add a callback to the callback manager."""
with self._lock: with self._lock:

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text from langchain.input import print_text
from langchain.schema import AgentAction, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler): class StdOutCallbackHandler(BaseCallbackHandler):
@ -76,3 +76,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""Run when agent ends.""" """Run when agent ends."""
print_text(text, color=color, end=end) print_text(text, color=color, end=end)
def on_agent_end(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color, end="\n")

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import streamlit as st import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
class StreamlitCallbackHandler(BaseCallbackHandler): class StreamlitCallbackHandler(BaseCallbackHandler):
@ -70,3 +70,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
"""Run on text.""" """Run on text."""
# st.write requires two spaces before a newline to render it # st.write requires two spaces before a newline to render it
st.write(text.replace("\n", " \n")) st.write(text.replace("\n", " \n"))
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
# st.write requires two spaces before a newline to render it
st.write(finish.log.replace("\n", " \n"))

View File

@ -2,7 +2,7 @@
from typing import Any, Dict, List from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult from langchain.schema import AgentAction, LLMResult, AgentFinish
class FakeCallbackHandler(BaseCallbackHandler): class FakeCallbackHandler(BaseCallbackHandler):
@ -62,3 +62,8 @@ class FakeCallbackHandler(BaseCallbackHandler):
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Run when agent is ending.""" """Run when agent is ending."""
self.ends += 1 self.ends += 1
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.ends += 1