add explicit agent end method (#486)

This commit is contained in:
Harrison Chase 2022-12-29 22:23:15 -05:00 committed by GitHub
parent 7e36f28e78
commit 52490e2dcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 39 additions and 22 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, root_validator
@ -16,7 +16,7 @@ from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
from langchain.schema import AgentAction, AgentFinish
logger = logging.getLogger()
@ -46,7 +46,7 @@ class Agent(BaseModel):
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> AgentAction:
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
@ -73,7 +73,7 @@ class Agent(BaseModel):
parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output
if tool == self.finish_tool_name:
return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
return AgentFinish(full_output, {"output": tool_input})
return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self) -> None:
@ -220,10 +220,7 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.verbose:
self._get_callback_manager().on_tool_start(
{"name": "Finish"}, output, color="green"
)
self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
self._get_callback_manager().on_agent_end(output.log, color="green")
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps

View File

@ -54,6 +54,10 @@ class BaseCallbackHandler(ABC):
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
@abstractmethod
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent ends."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
"""Base callback manager that can be used to handle callbacks from LangChain."""
@ -128,6 +132,11 @@ class CallbackManager(BaseCallbackManager):
for handler in self.handlers:
handler.on_tool_error(error)
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent ends."""
for handler in self.handlers:
handler.on_agent_end(log, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)

View File

@ -88,6 +88,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock:
self._callback_manager.on_tool_error(error)
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent ends."""
with self._lock:
self._callback_manager.on_agent_end(log, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
from langchain.schema import AgentAction, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler):
@ -59,11 +59,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if output != AGENT_FINISH_OBSERVATION:
print_text(f"\n{observation_prefix}")
print_text(output, color=color)
print_text(f"\n{llm_prefix}")
print_text(f"\n{observation_prefix}")
print_text(output, color=color)
print_text(f"\n{llm_prefix}")
def on_tool_error(self, error: Exception) -> None:
"""Do nothing."""
pass
def on_agent_end(
self, log: str, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run when agent ends."""
print_text(log, color=color)

View File

@ -1,13 +1,9 @@
"""Common schema objects."""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
AGENT_FINISH_OBSERVATION = "__agent_finish__"
@dataclass
class AgentAction:
class AgentAction(NamedTuple):
"""Agent's action to take."""
tool: str
@ -15,10 +11,10 @@ class AgentAction:
log: str
@dataclass
class AgentFinish(AgentAction):
class AgentFinish(NamedTuple):
"""Agent's return value."""
log: str
return_values: dict

View File

@ -90,8 +90,9 @@ def test_agent_with_callbacks() -> None:
output = agent.run("when was langchain made")
assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending
assert handler.starts == 7
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7
assert handler.errors == 0

View File

@ -58,3 +58,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.ends += 1