mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-20 19:41:23 +00:00
callback updates
This commit is contained in:
parent
23b8cfc123
commit
7eb33690a9
@ -259,7 +259,7 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.verbose:
|
||||
self.callback_manager.on_text(output.log, color="green")
|
||||
self.callback_manager.on_agent_end(output, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
|
@ -3,7 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class BaseCallbackHandler(ABC):
|
||||
@ -56,7 +56,11 @@ class BaseCallbackHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent ends."""
|
||||
"""Run on arbitrary text."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
|
||||
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
@ -141,6 +145,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
for handler in self.handlers:
|
||||
handler.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
for handler in self.handlers:
|
||||
handler.on_agent_end(finish, **kwargs)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
@ -8,7 +8,7 @@ from langchain.callbacks.base import (
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class Singleton:
|
||||
@ -89,10 +89,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
self._callback_manager.on_tool_error(error)
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent ends."""
|
||||
"""Run on arbitrary text."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_text(text, **kwargs)
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_end(finish, **kwargs)
|
||||
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a callback to the callback manager."""
|
||||
with self._lock:
|
||||
|
@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
@ -76,3 +76,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
"""Run when agent ends."""
|
||||
print_text(text, color=color, end=end)
|
||||
|
||||
def on_agent_end(
|
||||
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run on agent end."""
|
||||
print_text(finish.log, color=color, end="\n")
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
||||
import streamlit as st
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
@ -70,3 +70,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||
"""Run on text."""
|
||||
# st.write requires two spaces before a newline to render it
|
||||
st.write(text.replace("\n", " \n"))
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run on agent end."""
|
||||
# st.write requires two spaces before a newline to render it
|
||||
st.write(finish.log.replace("\n", " \n"))
|
||||
|
@ -2,7 +2,7 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
from langchain.schema import AgentAction, LLMResult, AgentFinish
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
@ -62,3 +62,8 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.ends += 1
|
||||
|
||||
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.ends += 1
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user