mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 18:33:40 +00:00
add explicit agent end method (#486)
This commit is contained in:
parent
7e36f28e78
commit
52490e2dcd
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ from langchain.llms.base import BaseLLM
|
|||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
|
from langchain.schema import AgentAction, AgentFinish
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ class Agent(BaseModel):
|
|||||||
|
|
||||||
def plan(
|
def plan(
|
||||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||||
) -> AgentAction:
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
"""Given input, decided what to do.
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -73,7 +73,7 @@ class Agent(BaseModel):
|
|||||||
parsed_output = self._extract_tool_and_input(full_output)
|
parsed_output = self._extract_tool_and_input(full_output)
|
||||||
tool, tool_input = parsed_output
|
tool, tool_input = parsed_output
|
||||||
if tool == self.finish_tool_name:
|
if tool == self.finish_tool_name:
|
||||||
return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
|
return AgentFinish(full_output, {"output": tool_input})
|
||||||
return AgentAction(tool, tool_input, full_output)
|
return AgentAction(tool, tool_input, full_output)
|
||||||
|
|
||||||
def prepare_for_new_call(self) -> None:
|
def prepare_for_new_call(self) -> None:
|
||||||
@ -220,10 +220,7 @@ class AgentExecutor(Chain, BaseModel):
|
|||||||
# If the tool chosen is the finishing tool, then we end and return.
|
# If the tool chosen is the finishing tool, then we end and return.
|
||||||
if isinstance(output, AgentFinish):
|
if isinstance(output, AgentFinish):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
self._get_callback_manager().on_tool_start(
|
self._get_callback_manager().on_agent_end(output.log, color="green")
|
||||||
{"name": "Finish"}, output, color="green"
|
|
||||||
)
|
|
||||||
self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
|
|
||||||
final_output = output.return_values
|
final_output = output.return_values
|
||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
final_output["intermediate_steps"] = intermediate_steps
|
final_output["intermediate_steps"] = intermediate_steps
|
||||||
|
@ -54,6 +54,10 @@ class BaseCallbackHandler(ABC):
|
|||||||
def on_tool_error(self, error: Exception) -> None:
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when agent ends."""
|
||||||
|
|
||||||
|
|
||||||
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
class BaseCallbackManager(BaseCallbackHandler, ABC):
|
||||||
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
"""Base callback manager that can be used to handle callbacks from LangChain."""
|
||||||
@ -128,6 +132,11 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
for handler in self.handlers:
|
for handler in self.handlers:
|
||||||
handler.on_tool_error(error)
|
handler.on_tool_error(error)
|
||||||
|
|
||||||
|
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when agent ends."""
|
||||||
|
for handler in self.handlers:
|
||||||
|
handler.on_agent_end(log, **kwargs)
|
||||||
|
|
||||||
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
def add_handler(self, handler: BaseCallbackHandler) -> None:
|
||||||
"""Add a handler to the callback manager."""
|
"""Add a handler to the callback manager."""
|
||||||
self.handlers.append(handler)
|
self.handlers.append(handler)
|
||||||
|
@ -88,6 +88,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._callback_manager.on_tool_error(error)
|
self._callback_manager.on_tool_error(error)
|
||||||
|
|
||||||
|
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when agent ends."""
|
||||||
|
with self._lock:
|
||||||
|
self._callback_manager.on_agent_end(log, **kwargs)
|
||||||
|
|
||||||
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
def add_handler(self, callback: BaseCallbackHandler) -> None:
|
||||||
"""Add a callback to the callback manager."""
|
"""Add a callback to the callback manager."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.input import print_text
|
from langchain.input import print_text
|
||||||
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
|
from langchain.schema import AgentAction, LLMResult
|
||||||
|
|
||||||
|
|
||||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||||
@ -59,7 +59,6 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""If not the final action, print out observation."""
|
"""If not the final action, print out observation."""
|
||||||
if output != AGENT_FINISH_OBSERVATION:
|
|
||||||
print_text(f"\n{observation_prefix}")
|
print_text(f"\n{observation_prefix}")
|
||||||
print_text(output, color=color)
|
print_text(output, color=color)
|
||||||
print_text(f"\n{llm_prefix}")
|
print_text(f"\n{llm_prefix}")
|
||||||
@ -67,3 +66,9 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_tool_error(self, error: Exception) -> None:
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_agent_end(
|
||||||
|
self, log: str, color: Optional[str] = None, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Run when agent ends."""
|
||||||
|
print_text(log, color=color)
|
||||||
|
@ -1,13 +1,9 @@
|
|||||||
"""Common schema objects."""
|
"""Common schema objects."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, NamedTuple, Optional
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
AGENT_FINISH_OBSERVATION = "__agent_finish__"
|
|
||||||
|
|
||||||
|
class AgentAction(NamedTuple):
|
||||||
@dataclass
|
|
||||||
class AgentAction:
|
|
||||||
"""Agent's action to take."""
|
"""Agent's action to take."""
|
||||||
|
|
||||||
tool: str
|
tool: str
|
||||||
@ -15,10 +11,10 @@ class AgentAction:
|
|||||||
log: str
|
log: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class AgentFinish(NamedTuple):
|
||||||
class AgentFinish(AgentAction):
|
|
||||||
"""Agent's return value."""
|
"""Agent's return value."""
|
||||||
|
|
||||||
|
log: str
|
||||||
return_values: dict
|
return_values: dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,8 +90,9 @@ def test_agent_with_callbacks() -> None:
|
|||||||
output = agent.run("when was langchain made")
|
output = agent.run("when was langchain made")
|
||||||
assert output == "curses foiled again"
|
assert output == "curses foiled again"
|
||||||
|
|
||||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending
|
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||||
assert handler.starts == 7
|
assert handler.starts == 6
|
||||||
|
# 1 extra agent end
|
||||||
assert handler.ends == 7
|
assert handler.ends == 7
|
||||||
assert handler.errors == 0
|
assert handler.errors == 0
|
||||||
|
|
||||||
|
@ -58,3 +58,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_tool_error(self, error: Exception) -> None:
|
def on_tool_error(self, error: Exception) -> None:
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
|
||||||
|
def on_agent_end(self, log: str, **kwargs: Any) -> None:
|
||||||
|
"""Run when agent is ending."""
|
||||||
|
self.ends += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user