mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 00:49:25 +00:00
Revert "[agent][property type] Change allowed_tools to Set as Duplicate doesn’t make sense" (#4014)
Reverts hwchase17/langchain#3840
This commit is contained in:
parent
df3bc707fc
commit
a5dd73c1a6
@ -7,7 +7,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
@ -48,8 +48,8 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
"""Return values of the agent."""
|
"""Return values of the agent."""
|
||||||
return ["output"]
|
return ["output"]
|
||||||
|
|
||||||
def get_allowed_tools(self) -> Set[str]:
|
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||||
return set()
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def plan(
|
def plan(
|
||||||
@ -180,8 +180,8 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
"""Return values of the agent."""
|
"""Return values of the agent."""
|
||||||
return ["output"]
|
return ["output"]
|
||||||
|
|
||||||
def get_allowed_tools(self) -> Set[str]:
|
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||||
return set()
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def plan(
|
def plan(
|
||||||
@ -374,9 +374,9 @@ class Agent(BaseSingleActionAgent):
|
|||||||
|
|
||||||
llm_chain: LLMChain
|
llm_chain: LLMChain
|
||||||
output_parser: AgentOutputParser
|
output_parser: AgentOutputParser
|
||||||
allowed_tools: Set[str] = set()
|
allowed_tools: Optional[List[str]] = None
|
||||||
|
|
||||||
def get_allowed_tools(self) -> Set[str]:
|
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||||
return self.allowed_tools
|
return self.allowed_tools
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -629,11 +629,12 @@ class AgentExecutor(Chain):
|
|||||||
agent = values["agent"]
|
agent = values["agent"]
|
||||||
tools = values["tools"]
|
tools = values["tools"]
|
||||||
allowed_tools = agent.get_allowed_tools()
|
allowed_tools = agent.get_allowed_tools()
|
||||||
if allowed_tools != set([tool.name for tool in tools]):
|
if allowed_tools is not None:
|
||||||
raise ValueError(
|
if set(allowed_tools) != set([tool.name for tool in tools]):
|
||||||
f"Allowed tools ({allowed_tools}) different than "
|
raise ValueError(
|
||||||
f"provided tools ({[tool.name for tool in tools]})"
|
f"Allowed tools ({allowed_tools}) different than "
|
||||||
)
|
f"provided tools ({[tool.name for tool in tools]})"
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
|
@ -1 +0,0 @@
|
|||||||
"""All integration tests for agent."""
|
|
@ -1,16 +0,0 @@
|
|||||||
from langchain.agents.chat.base import ChatAgent
|
|
||||||
from langchain.llms.openai import OpenAI
|
|
||||||
from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgent:
|
|
||||||
def test_agent_generation(self) -> None:
|
|
||||||
web_search = DuckDuckGoSearchRun()
|
|
||||||
tools = [web_search]
|
|
||||||
agent = ChatAgent.from_llm_and_tools(
|
|
||||||
ai_name="Tom",
|
|
||||||
ai_role="Assistant",
|
|
||||||
tools=tools,
|
|
||||||
llm=OpenAI(maxTokens=10),
|
|
||||||
)
|
|
||||||
assert agent.allowed_tools == set([web_search.name])
|
|
Loading…
Reference in New Issue
Block a user