mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
Move Tool Validation (#3923)
Move tool validation to each implementation of the Agent. Another alternative would be to adjust the `_validate_tools()` signature to accept the output parser (and format instructions) and add logic there. Something like `parser.outputs_structured_actions(format_instructions)` But don't think that's needed right now.
This commit is contained in:
parent
7cce68a051
commit
84ea17b786
@ -497,11 +497,7 @@ class Agent(BaseSingleActionAgent):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
"""Validate that appropriate tools are passed in."""
|
"""Validate that appropriate tools are passed in."""
|
||||||
for tool in tools:
|
pass
|
||||||
if not tool.is_single_input:
|
|
||||||
raise ValueError(
|
|
||||||
f"{cls.__name__} does not support multi-input tool {tool.name}."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -5,6 +5,7 @@ from pydantic import Field
|
|||||||
from langchain.agents.agent import Agent, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentOutputParser
|
||||||
from langchain.agents.chat.output_parser import ChatOutputParser
|
from langchain.agents.chat.output_parser import ChatOutputParser
|
||||||
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||||
|
from langchain.agents.utils import validate_tools_single_input
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -15,7 +16,7 @@ from langchain.prompts.chat import (
|
|||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema import AgentAction
|
from langchain.schema import AgentAction
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(Agent):
|
class ChatAgent(Agent):
|
||||||
@ -50,6 +51,11 @@ class ChatAgent(Agent):
|
|||||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||||
return ChatOutputParser()
|
return ChatOutputParser()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
|
validate_tools_single_input(class_name=cls.__name__, tools=tools)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _stop(self) -> List[str]:
|
def _stop(self) -> List[str]:
|
||||||
return ["Observation:"]
|
return ["Observation:"]
|
||||||
|
@ -9,6 +9,7 @@ from langchain.agents.agent import Agent, AgentOutputParser
|
|||||||
from langchain.agents.agent_types import AgentType
|
from langchain.agents.agent_types import AgentType
|
||||||
from langchain.agents.conversational.output_parser import ConvoOutputParser
|
from langchain.agents.conversational.output_parser import ConvoOutputParser
|
||||||
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||||
|
from langchain.agents.utils import validate_tools_single_input
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
@ -80,6 +81,11 @@ class ConversationalAgent(Agent):
|
|||||||
input_variables = ["input", "chat_history", "agent_scratchpad"]
|
input_variables = ["input", "chat_history", "agent_scratchpad"]
|
||||||
return PromptTemplate(template=template, input_variables=input_variables)
|
return PromptTemplate(template=template, input_variables=input_variables)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm_and_tools(
|
def from_llm_and_tools(
|
||||||
cls,
|
cls,
|
||||||
|
@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
|
|||||||
SUFFIX,
|
SUFFIX,
|
||||||
TEMPLATE_TOOL_RESPONSE,
|
TEMPLATE_TOOL_RESPONSE,
|
||||||
)
|
)
|
||||||
|
from langchain.agents.utils import validate_tools_single_input
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
@ -55,6 +56,11 @@ class ConversationalChatAgent(Agent):
|
|||||||
"""Prefix to append the llm call with."""
|
"""Prefix to append the llm call with."""
|
||||||
return "Thought:"
|
return "Thought:"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
super()._validate_tools(tools)
|
||||||
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_prompt(
|
def create_prompt(
|
||||||
cls,
|
cls,
|
||||||
|
@ -10,6 +10,7 @@ from langchain.agents.agent_types import AgentType
|
|||||||
from langchain.agents.mrkl.output_parser import MRKLOutputParser
|
from langchain.agents.mrkl.output_parser import MRKLOutputParser
|
||||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
|
from langchain.agents.utils import validate_tools_single_input
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
@ -122,13 +123,14 @@ class ZeroShotAgent(Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
super()._validate_tools(tools)
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.description is None:
|
if tool.description is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got a tool {tool.name} without a description. For this agent, "
|
f"Got a tool {tool.name} without a description. For this agent, "
|
||||||
f"a description must always be provided."
|
f"a description must always be provided."
|
||||||
)
|
)
|
||||||
|
super()._validate_tools(tools)
|
||||||
|
|
||||||
|
|
||||||
class MRKLChain(AgentExecutor):
|
class MRKLChain(AgentExecutor):
|
||||||
|
@ -9,6 +9,7 @@ from langchain.agents.react.output_parser import ReActOutputParser
|
|||||||
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
|
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
|
||||||
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
|
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
|
from langchain.agents.utils import validate_tools_single_input
|
||||||
from langchain.docstore.base import Docstore
|
from langchain.docstore.base import Docstore
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
@ -37,6 +38,7 @@ class ReActDocstoreAgent(Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
super()._validate_tools(tools)
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 2:
|
if len(tools) != 2:
|
||||||
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
raise ValueError(f"Exactly two tools must be specified, but got {tools}")
|
||||||
@ -120,6 +122,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
super()._validate_tools(tools)
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 1:
|
if len(tools) != 1:
|
||||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||||
|
@ -8,6 +8,7 @@ from langchain.agents.agent_types import AgentType
|
|||||||
from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser
|
from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser
|
||||||
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
|
from langchain.agents.utils import validate_tools_single_input
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
@ -36,6 +37,7 @@ class SelfAskWithSearchAgent(Agent):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||||
|
validate_tools_single_input(cls.__name__, tools)
|
||||||
super()._validate_tools(tools)
|
super()._validate_tools(tools)
|
||||||
if len(tools) != 1:
|
if len(tools) != 1:
|
||||||
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
raise ValueError(f"Exactly one tool must be specified, but got {tools}")
|
||||||
|
12
langchain/agents/utils.py
Normal file
12
langchain/agents/utils.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tools_single_input(class_name: str, tools: Sequence[BaseTool]) -> None:
|
||||||
|
"""Validate tools for single input."""
|
||||||
|
for tool in tools:
|
||||||
|
if not tool.is_single_input:
|
||||||
|
raise ValueError(
|
||||||
|
f"{class_name} does not support multi-input tool {tool.name}."
|
||||||
|
)
|
@ -400,8 +400,8 @@ async def test_create_async_tool() -> None:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"agent_cls",
|
"agent_cls",
|
||||||
[
|
[
|
||||||
ChatAgent,
|
|
||||||
ZeroShotAgent,
|
ZeroShotAgent,
|
||||||
|
ChatAgent,
|
||||||
ConversationalChatAgent,
|
ConversationalChatAgent,
|
||||||
ConversationalAgent,
|
ConversationalAgent,
|
||||||
ReActDocstoreAgent,
|
ReActDocstoreAgent,
|
||||||
|
Loading…
Reference in New Issue
Block a user