From 52490e2dcdca85d2e92c6abca1254ea3e229dff1 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 29 Dec 2022 22:23:15 -0500 Subject: [PATCH] add explicit agent end method (#486) --- langchain/agents/agent.py | 13 +++++-------- langchain/callbacks/base.py | 9 +++++++++ langchain/callbacks/shared.py | 5 +++++ langchain/callbacks/stdout.py | 15 ++++++++++----- langchain/schema.py | 10 +++------- tests/unit_tests/agents/test_agent.py | 5 +++-- .../unit_tests/callbacks/fake_callback_handler.py | 4 ++++ 7 files changed, 39 insertions(+), 22 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 5a946527d0d..d6f6db29870 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from abc import abstractmethod -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, root_validator @@ -16,7 +16,7 @@ from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish +from langchain.schema import AgentAction, AgentFinish logger = logging.getLogger() @@ -46,7 +46,7 @@ class Agent(BaseModel): def plan( self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any - ) -> AgentAction: + ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: @@ -73,7 +73,7 @@ class Agent(BaseModel): parsed_output = self._extract_tool_and_input(full_output) tool, tool_input = parsed_output if tool == self.finish_tool_name: - return AgentFinish(tool, tool_input, full_output, {"output": tool_input}) + return AgentFinish(full_output, {"output": tool_input}) return AgentAction(tool, tool_input, full_output) def prepare_for_new_call(self) -> None: @@ -220,10 +220,7 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: - self._get_callback_manager().on_tool_start( - {"name": "Finish"}, output, color="green" - ) - self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION) + self._get_callback_manager().on_agent_end(output.log, color="green") final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index ca0539b1ff2..fa036016708 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -54,6 +54,10 @@ class BaseCallbackHandler(ABC): def on_tool_error(self, error: Exception) -> None: """Run when tool errors.""" + @abstractmethod + def on_agent_end(self, log: str, **kwargs: Any) -> None: + """Run when agent ends.""" + class BaseCallbackManager(BaseCallbackHandler, ABC): """Base callback manager that can be used to handle callbacks from LangChain.""" @@ -128,6 +132,11 @@ class CallbackManager(BaseCallbackManager): for handler in self.handlers: handler.on_tool_error(error) + def on_agent_end(self, log: str, **kwargs: Any) -> None: + """Run when agent ends.""" + for handler in self.handlers: + handler.on_agent_end(log, **kwargs) + def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" self.handlers.append(handler) diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index 10225c7769d..006df55a824 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -88,6 +88,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.on_tool_error(error) + def on_agent_end(self, log: str, **kwargs: Any) -> None: + """Run when agent ends.""" + with self._lock: + self._callback_manager.on_agent_end(log, **kwargs) + def add_handler(self, callback: BaseCallbackHandler) -> None: """Add a callback to the callback manager.""" with self._lock: diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index f16a287a366..ee72c8cc7ad 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text -from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult +from langchain.schema import AgentAction, LLMResult class StdOutCallbackHandler(BaseCallbackHandler): @@ -59,11 +59,16 @@ class StdOutCallbackHandler(BaseCallbackHandler): **kwargs: Any, ) -> None: """If not the final action, print out observation.""" - if output != AGENT_FINISH_OBSERVATION: - print_text(f"\n{observation_prefix}") - print_text(output, color=color) - print_text(f"\n{llm_prefix}") + print_text(f"\n{observation_prefix}") + print_text(output, color=color) + print_text(f"\n{llm_prefix}") def on_tool_error(self, error: Exception) -> None: """Do nothing.""" pass + + def on_agent_end( + self, log: str, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run when agent ends.""" + print_text(log, color=color) diff --git a/langchain/schema.py b/langchain/schema.py index ce2811fded4..8c6d8bbea60 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -1,13 +1,9 @@ """Common schema objects.""" -from dataclasses import dataclass from typing import List, NamedTuple, Optional -AGENT_FINISH_OBSERVATION = "__agent_finish__" - -@dataclass -class AgentAction: +class AgentAction(NamedTuple): """Agent's action to take.""" tool: str @@ -15,10 +11,10 @@ class AgentAction: log: str -@dataclass -class AgentFinish(AgentAction): +class AgentFinish(NamedTuple): """Agent's return value.""" + log: str return_values: dict diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 0bbae093b63..c48af48b75f 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -90,8 +90,9 @@ def test_agent_with_callbacks() -> None: output = agent.run("when was langchain made") assert output == "curses foiled again" - # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending - assert handler.starts == 7 + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.starts == 6 + # 1 extra agent end assert handler.ends == 7 assert handler.errors == 0 diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 5408042f5f1..18f88d1f2fd 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -58,3 +58,7 @@ class FakeCallbackHandler(BaseCallbackHandler): def on_tool_error(self, error: Exception) -> None: """Run when tool errors.""" self.errors += 1 + + def on_agent_end(self, log: str, **kwargs: Any) -> None: + """Run when agent is ending.""" + self.ends += 1