mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-21 03:51:42 +00:00
callback updates
This commit is contained in:
parent
23b8cfc123
commit
7eb33690a9
@ -259,7 +259,7 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
# If the tool chosen is the finishing tool, then we end and return.
|
# If the tool chosen is the finishing tool, then we end and return.
|
||||||
if isinstance(output, AgentFinish):
|
if isinstance(output, AgentFinish):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self.callback_manager.on_text(output.log, color="green")
|
self.callback_manager.on_agent_end(output, color="green")
|
||||||
final_output = output.return_values
|
final_output = output.return_values
|
||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
final_output["intermediate_steps"] = intermediate_steps
|
final_output["intermediate_steps"] = intermediate_steps
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain.schema import AgentAction, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class BaseCallbackHandler(ABC):
|
class BaseCallbackHandler(ABC):
|
||||||
@ -56,7 +56,11 @@ class BaseCallbackHandler(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
"""Run when agent ends."""
|
"""Run on arbitrary text."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
|
||||||
|
|
||||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||||
@ -141,6 +145,11 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
handler.on_text(text, **kwargs)
|
handler.on_text(text, **kwargs)
|
||||||
|
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_agent_end(finish, **kwargs)
|
||||||
|
|
||||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
"""Add a handler to the callback manager."""
|
"""Add a handler to the callback manager."""
|
||||||
self.handlers.append(handler)
|
self.handlers.append(handler)
|
||||||
|
@ -8,7 +8,7 @@ from langchain.callbacks.base import (
|
|||||||
BaseCallbackManager,
|
BaseCallbackManager,
|
||||||
CallbackManager,
|
CallbackManager,
|
||||||
)
|
)
|
||||||
from langchain.schema import AgentAction, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class Singleton:
|
class Singleton:
|
||||||
@ -89,10 +89,15 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
|||||||
self._callback_manager.on_tool_error(error)
|
self._callback_manager.on_tool_error(error)
|
||||||
|
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
"""Run when agent ends."""
|
"""Run on arbitrary text."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._callback_manager.on_text(text, **kwargs)
|
self._callback_manager.on_text(text, **kwargs)
|
||||||
|
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_agent_end(finish, **kwargs)
|
||||||
|
|
||||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
"""Add a callback to the callback manager."""
|
"""Add a callback to the callback manager."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.schema import AgentAction, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||||
@ -76,3 +76,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Run when agent ends."""
|
"""Run when agent ends."""
|
||||||
print_text(text, color=color, end=end)
|
print_text(text, color=color, end=end)
|
||||||
|
|
||||||
|
def on_agent_end(
|
||||||
|
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
print_text(finish.log, color=color, end="\n")
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class StreamlitCallbackHandler(BaseCallbackHandler):
|
class StreamlitCallbackHandler(BaseCallbackHandler):
|
||||||
@ -70,3 +70,8 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Run on text."""
|
"""Run on text."""
|
||||||
# st.write requires two spaces before a newline to render it
|
# st.write requires two spaces before a newline to render it
|
||||||
st.write(text.replace("\n", " \n"))
|
st.write(text.replace("\n", " \n"))
|
||||||
|
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run on agent end."""
|
||||||
|
# st.write requires two spaces before a newline to render it
|
||||||
|
st.write(finish.log.replace("\n", " \n"))
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, LLMResult
|
from langchain.schema import AgentAction, LLMResult, AgentFinish
|
||||||
|
|
||||||
|
|
||||||
class FakeCallbackHandler(BaseCallbackHandler):
|
class FakeCallbackHandler(BaseCallbackHandler):
|
||||||
@ -62,3 +62,8 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
"""Run when agent is ending."""
|
"""Run when agent is ending."""
|
||||||
self.ends += 1
|
self.ends += 1
|
||||||
|
|
||||||
|
def on_agent_end(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||||
|
"""Run when agent ends running."""
|
||||||
|
self.ends += 1
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user