mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
add explicit agent end method (#486)
This commit is contained in:
parent
7e36f28e78
commit
52490e2dcd
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
@ -16,7 +16,7 @@ from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@ -46,7 +46,7 @@ class Agent(BaseModel):
|
||||
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> AgentAction:
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
@ -73,7 +73,7 @@ class Agent(BaseModel):
|
||||
parsed_output = self._extract_tool_and_input(full_output)
|
||||
tool, tool_input = parsed_output
|
||||
if tool == self.finish_tool_name:
|
||||
return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
|
||||
return AgentFinish(full_output, {"output": tool_input})
|
||||
return AgentAction(tool, tool_input, full_output)
|
||||
|
||||
def prepare_for_new_call(self) -> None:
|
||||
@ -220,10 +220,7 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.verbose:
|
||||
self._get_callback_manager().on_tool_start(
|
||||
{"name": "Finish"}, output, color="green"
|
||||
)
|
||||
self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
|
||||
self._get_callback_manager().on_agent_end(output.log, color="green")
|
||||
final_output = output.return_values
|
||||
if self.return_intermediate_steps:
|
||||
final_output["intermediate_steps"] = intermediate_steps
|
||||
|
@ -54,6 +54,10 @@ class BaseCallbackHandler(ABC):
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
@abstractmethod
|
||||
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||
"""Run when agent ends."""
|
||||
|
||||
|
||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||
@ -128,6 +132,11 @@ class CallbackManager(BaseCallbackManager):
|
||||
for handler in self.handlers:
|
||||
handler.on_tool_error(error)
|
||||
|
||||
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||
"""Run when agent ends."""
|
||||
for handler in self.handlers:
|
||||
handler.on_agent_end(log, **kwargs)
|
||||
|
||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||
"""Add a handler to the callback manager."""
|
||||
self.handlers.append(handler)
|
||||
|
@ -88,6 +88,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
||||
with self._lock:
|
||||
self._callback_manager.on_tool_error(error)
|
||||
|
||||
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||
"""Run when agent ends."""
|
||||
with self._lock:
|
||||
self._callback_manager.on_agent_end(log, **kwargs)
|
||||
|
||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||
"""Add a callback to the callback manager."""
|
||||
with self._lock:
|
||||
|
@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.input import print_text
|
||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
|
||||
from langchain.schema import AgentAction, LLMResult
|
||||
|
||||
|
||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
@ -59,11 +59,16 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""If not the final action, print out observation."""
|
||||
if output != AGENT_FINISH_OBSERVATION:
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
print_text(f"\n{observation_prefix}")
|
||||
print_text(output, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Do nothing."""
|
||||
pass
|
||||
|
||||
def on_agent_end(
|
||||
self, log: str, color: Optional[str] = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when agent ends."""
|
||||
print_text(log, color=color)
|
||||
|
@ -1,13 +1,9 @@
|
||||
"""Common schema objects."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
AGENT_FINISH_OBSERVATION = "__agent_finish__"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
class AgentAction(NamedTuple):
|
||||
"""Agent's action to take."""
|
||||
|
||||
tool: str
|
||||
@ -15,10 +11,10 @@ class AgentAction:
|
||||
log: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentFinish(AgentAction):
|
||||
class AgentFinish(NamedTuple):
|
||||
"""Agent's return value."""
|
||||
|
||||
log: str
|
||||
return_values: dict
|
||||
|
||||
|
||||
|
@ -90,8 +90,9 @@ def test_agent_with_callbacks() -> None:
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending
|
||||
assert handler.starts == 7
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 6
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
|
||||
|
@ -58,3 +58,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.ends += 1
|
||||
|
Loading…
Reference in New Issue
Block a user