add explicit agent end method (#486)

This commit is contained in:
Harrison Chase 2022-12-29 22:23:15 -05:00 committed by GitHub
parent 7e36f28e78
commit 52490e2dcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 39 additions and 22 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
@ -16,7 +16,7 @@ from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish from langchain.schema import AgentAction, AgentFinish
logger = logging.getLogger() logger = logging.getLogger()
@ -46,7 +46,7 @@ class Agent(BaseModel):
def plan( def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> AgentAction: ) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
Args: Args:
@ -73,7 +73,7 @@ class Agent(BaseModel):
parsed_output = self._extract_tool_and_input(full_output) parsed_output = self._extract_tool_and_input(full_output)
tool, tool_input = parsed_output tool, tool_input = parsed_output
if tool == self.finish_tool_name: if tool == self.finish_tool_name:
return AgentFinish(tool, tool_input, full_output, {"output": tool_input}) return AgentFinish(full_output, {"output": tool_input})
return AgentAction(tool, tool_input, full_output) return AgentAction(tool, tool_input, full_output)
def prepare_for_new_call(self) -> None: def prepare_for_new_call(self) -> None:
@ -220,10 +220,7 @@ class AgentExecutor(Chain, BaseModel):
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
if self.verbose: if self.verbose:
self._get_callback_manager().on_tool_start( self._get_callback_manager().on_agent_end(output.log, color="green")
{"name": "Finish"}, output, color="green"
)
self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
final_output = output.return_values final_output = output.return_values
if self.return_intermediate_steps: if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps final_output["intermediate_steps"] = intermediate_steps

View File

@ -54,6 +54,10 @@ class BaseCallbackHandler(ABC):
def on_tool_error(self, error: Exception) -> None: def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors.""" """Run when tool errors."""
@abstractmethod
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent ends."""
class BaseCallbackManager(BaseCallbackHandler, ABC): class BaseCallbackManager(BaseCallbackHandler, ABC):
"""Base callback manager that can be used to handle callbacks from LangChain.""" """Base callback manager that can be used to handle callbacks from LangChain."""
@ -128,6 +132,11 @@ class CallbackManager(BaseCallbackManager):
for handler in self.handlers: for handler in self.handlers:
handler.on_tool_error(error) handler.on_tool_error(error)
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent ends."""
for handler in self.handlers:
handler.on_agent_end(log, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None: def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager.""" """Add a handler to the callback manager."""
self.handlers.append(handler) self.handlers.append(handler)

View File

@ -88,6 +88,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
with self._lock: with self._lock:
self._callback_manager.on_tool_error(error) self._callback_manager.on_tool_error(error)
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent ends."""
with self._lock:
self._callback_manager.on_agent_end(log, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None: def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager.""" """Add a callback to the callback manager."""
with self._lock: with self._lock:

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text from langchain.input import print_text
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult from langchain.schema import AgentAction, LLMResult
class StdOutCallbackHandler(BaseCallbackHandler): class StdOutCallbackHandler(BaseCallbackHandler):
@ -59,7 +59,6 @@ class StdOutCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""If not the final action, print out observation.""" """If not the final action, print out observation."""
if output != AGENT_FINISH_OBSERVATION:
print_text(f"\n{observation_prefix}") print_text(f"\n{observation_prefix}")
print_text(output, color=color) print_text(output, color=color)
print_text(f"\n{llm_prefix}") print_text(f"\n{llm_prefix}")
@ -67,3 +66,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
def on_tool_error(self, error: Exception) -> None: def on_tool_error(self, error: Exception) -> None:
"""Do nothing.""" """Do nothing."""
pass pass
def on_agent_end(
self, log: str, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run when agent ends."""
print_text(log, color=color)

View File

@ -1,13 +1,9 @@
"""Common schema objects.""" """Common schema objects."""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional from typing import List, NamedTuple, Optional
AGENT_FINISH_OBSERVATION = "__agent_finish__"
class AgentAction(NamedTuple):
@dataclass
class AgentAction:
"""Agent's action to take.""" """Agent's action to take."""
tool: str tool: str
@ -15,10 +11,10 @@ class AgentAction:
log: str log: str
@dataclass class AgentFinish(NamedTuple):
class AgentFinish(AgentAction):
"""Agent's return value.""" """Agent's return value."""
log: str
return_values: dict return_values: dict

View File

@ -90,8 +90,9 @@ def test_agent_with_callbacks() -> None:
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "curses foiled again" assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 7 assert handler.starts == 6
# 1 extra agent end
assert handler.ends == 7 assert handler.ends == 7
assert handler.errors == 0 assert handler.errors == 0

View File

@ -58,3 +58,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
def on_tool_error(self, error: Exception) -> None: def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.errors += 1 self.errors += 1
def on_agent_end(self, log: str, **kwargs: Any) -> None:
"""Run when agent is ending."""
self.ends += 1