From ec21b7126c225693a1fb28d9c60380515df72f09 Mon Sep 17 00:00:00 2001 From: Mike Wang <62768671+skcoirz@users.noreply.github.com> Date: Mon, 1 May 2023 20:30:10 -0700 Subject: [PATCH] =?UTF-8?q?[agent][property=20type]=20Change=20allowed=5Ft?= =?UTF-8?q?ools=20to=20Set=20as=20Duplicate=20doesn=E2=80=99t=20make=20sen?= =?UTF-8?q?se=20(#3840)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ActionAgent has a property called, `allowed_tools`, which is declared as `List`. It stores all provided tools which is available to use during agent action. - This collection shouldn’t allow duplicates. The original datatype List doesn’t make sense. Each tool should be unique. Even when there are variants (assuming in the future), it would be named differently in load_tools. Test: - confirm the functionality in an example by initializing an agent with a list of 2 tools and confirm everything works. ```python3 def test_agent_chain_chat_bot(): from langchain.agents import load_tools from langchain.agents import initialize_agent from langchain.agents import AgentType from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper chat = ChatOpenAI(temperature=0) llm = OpenAI(temperature=0) tools = load_tools(["ddg-search", "llm-math"], llm=llm) agent = initialize_agent(tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True) agent.run("Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?") test_agent_chain_chat_bot() ``` Result: Screenshot 2023-05-01 at 7 58 11 PM --- langchain/agents/agent.py | 25 ++++++++++----------- tests/integration_tests/agent/__init__.py | 1 + tests/integration_tests/agent/test_agent.py | 16 +++++++++++++ 3 files changed, 29 insertions(+), 13 deletions(-) create mode 100644 tests/integration_tests/agent/__init__.py create mode 100644 tests/integration_tests/agent/test_agent.py diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index eb9b859fab3..ca382031bb2 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -7,7 +7,7 @@ import logging import time from abc import abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union import yaml from pydantic import BaseModel, root_validator @@ -46,8 +46,8 @@ class BaseSingleActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Optional[List[str]]: - return None + def get_allowed_tools(self) -> Set[str]: + return set() @abstractmethod def plan( @@ -178,8 +178,8 @@ class BaseMultiActionAgent(BaseModel): """Return values of the agent.""" return ["output"] - def get_allowed_tools(self) -> Optional[List[str]]: - return None + def get_allowed_tools(self) -> Set[str]: + return set() @abstractmethod def plan( @@ -372,9 +372,9 @@ class Agent(BaseSingleActionAgent): llm_chain: LLMChain output_parser: AgentOutputParser - allowed_tools: Optional[List[str]] = None + allowed_tools: Set[str] = set() - def get_allowed_tools(self) -> Optional[List[str]]: + def get_allowed_tools(self) -> Set[str]: return self.allowed_tools @property @@ -607,12 +607,11 @@ class AgentExecutor(Chain): agent = values["agent"] tools = values["tools"] allowed_tools = agent.get_allowed_tools() - if allowed_tools is not None: - if set(allowed_tools) != set([tool.name for tool in tools]): - raise ValueError( - f"Allowed tools ({allowed_tools}) different than " - f"provided tools ({[tool.name for tool in tools]})" - ) + if allowed_tools != set([tool.name for tool in tools]): + raise ValueError( + f"Allowed tools ({allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) return values @root_validator() diff --git a/tests/integration_tests/agent/__init__.py b/tests/integration_tests/agent/__init__.py new file mode 100644 index 00000000000..117480e1e80 --- /dev/null +++ b/tests/integration_tests/agent/__init__.py @@ -0,0 +1 @@ +"""All integration tests for agent.""" diff --git a/tests/integration_tests/agent/test_agent.py b/tests/integration_tests/agent/test_agent.py new file mode 100644 index 00000000000..b423414edd8 --- /dev/null +++ b/tests/integration_tests/agent/test_agent.py @@ -0,0 +1,16 @@ +from langchain.agents.chat.base import ChatAgent +from langchain.llms.openai import OpenAI +from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun + + +class TestAgent: + def test_agent_generation(self) -> None: + web_search = DuckDuckGoSearchRun() + tools = [web_search] + agent = ChatAgent.from_llm_and_tools( + ai_name="Tom", + ai_role="Assistant", + tools=tools, + llm=OpenAI(maxTokens=10), + ) + assert agent.allowed_tools == set([web_search.name])