Revert "[agent][property type] Change allowed_tools to Set as Duplicate doesn’t make sense" (#4014)

Reverts hwchase17/langchain#3840
This commit is contained in:
Harrison Chase 2023-05-02 18:58:05 -07:00 committed by GitHub
parent df3bc707fc
commit a5dd73c1a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 29 deletions

View File

@ -7,7 +7,7 @@ import logging
import time import time
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml import yaml
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
@ -48,8 +48,8 @@ class BaseSingleActionAgent(BaseModel):
"""Return values of the agent.""" """Return values of the agent."""
return ["output"] return ["output"]
def get_allowed_tools(self) -> Set[str]: def get_allowed_tools(self) -> Optional[List[str]]:
return set() return None
@abstractmethod @abstractmethod
def plan( def plan(
@ -180,8 +180,8 @@ class BaseMultiActionAgent(BaseModel):
"""Return values of the agent.""" """Return values of the agent."""
return ["output"] return ["output"]
def get_allowed_tools(self) -> Set[str]: def get_allowed_tools(self) -> Optional[List[str]]:
return set() return None
@abstractmethod @abstractmethod
def plan( def plan(
@ -374,9 +374,9 @@ class Agent(BaseSingleActionAgent):
llm_chain: LLMChain llm_chain: LLMChain
output_parser: AgentOutputParser output_parser: AgentOutputParser
allowed_tools: Set[str] = set() allowed_tools: Optional[List[str]] = None
def get_allowed_tools(self) -> Set[str]: def get_allowed_tools(self) -> Optional[List[str]]:
return self.allowed_tools return self.allowed_tools
@property @property
@ -629,11 +629,12 @@ class AgentExecutor(Chain):
agent = values["agent"] agent = values["agent"]
tools = values["tools"] tools = values["tools"]
allowed_tools = agent.get_allowed_tools() allowed_tools = agent.get_allowed_tools()
if allowed_tools != set([tool.name for tool in tools]): if allowed_tools is not None:
raise ValueError( if set(allowed_tools) != set([tool.name for tool in tools]):
f"Allowed tools ({allowed_tools}) different than " raise ValueError(
f"provided tools ({[tool.name for tool in tools]})" f"Allowed tools ({allowed_tools}) different than "
) f"provided tools ({[tool.name for tool in tools]})"
)
return values return values
@root_validator() @root_validator()

View File

@ -1 +0,0 @@
"""All integration tests for agent."""

View File

@ -1,16 +0,0 @@
from langchain.agents.chat.base import ChatAgent
from langchain.llms.openai import OpenAI
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
class TestAgent:
def test_agent_generation(self) -> None:
web_search = DuckDuckGoSearchRun()
tools = [web_search]
agent = ChatAgent.from_llm_and_tools(
ai_name="Tom",
ai_role="Assistant",
tools=tools,
llm=OpenAI(maxTokens=10),
)
assert agent.allowed_tools == set([web_search.name])