mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 19:49:09 +00:00
Move Tool Validation (#3923)
Move tool validation to each implementation of the Agent. Another alternative would be to adjust the `_validate_tools()` signature to accept the output parser (and format instructions) and add logic there. Something like `parser.outputs_structured_actions(format_instructions)` But don't think that's needed right now.
This commit is contained in:
parent
7cce68a051
commit
84ea17b786
@ -497,11 +497,7 @@ class Agent(BaseSingleActionAgent):
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
"""Validate that appropriate tools are passed in."""
|
||||
for tool in tools:
|
||||
if not tool.is_single_input:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} does not support multi-input tool {tool.name}."
|
||||
)
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
|
@ -5,6 +5,7 @@ from pydantic import Field
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.chat.output_parser import ChatOutputParser
|
||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
@ -15,7 +16,7 @@ from langchain.prompts.chat import (
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.schema import AgentAction
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
class ChatAgent(Agent):
|
||||
@ -50,6 +51,11 @@ class ChatAgent(Agent):
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||
return ChatOutputParser()
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
validate_tools_single_input(class_name=cls.__name__, tools=tools)
|
||||
|
||||
@property
|
||||
def _stop(self) -> List[str]:
|
||||
return ["Observation:"]
|
||||
|
@ -9,6 +9,7 @@ from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.conversational.output_parser import ConvoOutputParser
|
||||
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
@ -80,6 +81,11 @@ class ConversationalAgent(Agent):
|
||||
input_variables = ["input", "chat_history", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
|
@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
|
||||
SUFFIX,
|
||||
TEMPLATE_TOOL_RESPONSE,
|
||||
)
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
@ -55,6 +56,11 @@ class ConversationalChatAgent(Agent):
|
||||
"""Prefix to append the llm call with."""
|
||||
return "Thought:"
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
|
@ -10,6 +10,7 @@ from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.output_parser import MRKLOutputParser
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains import LLMChain
|
||||
@ -122,13 +123,14 @@ class ZeroShotAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
super()._validate_tools(tools)
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
for tool in tools:
|
||||
if tool.description is None:
|
||||
raise ValueError(
|
||||
f"Got a tool {tool.name} without a description. For this agent, "
|
||||
f"a description must always be provided."
|
||||
)
|
||||
super()._validate_tools(tools)
|
||||
|
||||
|
||||
class MRKLChain(AgentExecutor):
|
||||
|
@ -9,6 +9,7 @@ from langchain.agents.react.output_parser import ReActOutputParser
|
||||
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
|
||||
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import BaseLLM
|
||||
@ -37,6 +38,7 @@ class ReActDocstoreAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 2:
|
||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||
@ -120,6 +122,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
|
@ -8,6 +8,7 @@ from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser
|
||||
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.agents.utils import validate_tools_single_input
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.tools.base import BaseTool
|
||||
@ -36,6 +37,7 @@ class SelfAskWithSearchAgent(Agent):
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
validate_tools_single_input(cls.__name__, tools)
|
||||
super()._validate_tools(tools)
|
||||
if len(tools) != 1:
|
||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||
|
12
langchain/agents/utils.py
Normal file
12
langchain/agents/utils.py
Normal file
@ -0,0 +1,12 @@
|
||||
from typing import Sequence
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:
|
||||
"""Validate tools for single input."""
|
||||
for tool in tools:
|
||||
if not tool.is_single_input:
|
||||
raise ValueError(
|
||||
f"{class_name} does not support multi-input tool {tool.name}."
|
||||
)
|
@ -400,8 +400,8 @@ async def test_create_async_tool() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"agent_cls",
|
||||
[
|
||||
ChatAgent,
|
||||
ZeroShotAgent,
|
||||
ChatAgent,
|
||||
ConversationalChatAgent,
|
||||
ConversationalAgent,
|
||||
ReActDocstoreAgent,
|
||||
|
Loading…
Reference in New Issue
Block a user