diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index e071b783212..50d24d7d82e 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,32 +1,17 @@ """Functionality for loading agents.""" import json from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, List, Optional, Union import yaml from langchain.agents.agent import BaseSingleActionAgent -from langchain.agents.agent_types import AgentType -from langchain.agents.chat.base import ChatAgent -from langchain.agents.conversational.base import ConversationalAgent -from langchain.agents.conversational_chat.base import ConversationalChatAgent -from langchain.agents.mrkl.base import ZeroShotAgent -from langchain.agents.react.base import ReActDocstoreAgent -from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool +from langchain.agents.types import AGENT_TO_CLASS from langchain.chains.loading import load_chain, load_chain_from_config from langchain.llms.base import BaseLLM from langchain.utilities.loading import try_load_from_hub -AGENT_TO_CLASS: Dict[AgentType, Type[BaseSingleActionAgent]] = { - AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent, - AgentType.REACT_DOCSTORE: ReActDocstoreAgent, - AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent, - AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent, - AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent, - AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent, -} - URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" diff --git a/langchain/agents/types.py b/langchain/agents/types.py new file mode 100644 index 00000000000..c3d8b2a57fc --- /dev/null +++ b/langchain/agents/types.py @@ -0,0 +1,19 @@ +from typing import Dict, Type + +from langchain.agents.agent import BaseSingleActionAgent +from langchain.agents.agent_types import AgentType +from langchain.agents.chat.base import ChatAgent +from langchain.agents.conversational.base import ConversationalAgent +from langchain.agents.conversational_chat.base import ConversationalChatAgent +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.react.base import ReActDocstoreAgent +from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent + +AGENT_TO_CLASS: Dict[AgentType, Type[BaseSingleActionAgent]] = { + AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent, + AgentType.REACT_DOCSTORE: ReActDocstoreAgent, + AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent, + AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent, + AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent, + AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent, +} diff --git a/tests/unit_tests/agents/test_types.py b/tests/unit_tests/agents/test_types.py new file mode 100644 index 00000000000..536d1f1d177 --- /dev/null +++ b/tests/unit_tests/agents/test_types.py @@ -0,0 +1,9 @@ +import unittest + +from langchain.agents.agent_types import AgentType +from langchain.agents.types import AGENT_TO_CLASS + + +class TestTypes(unittest.TestCase): + def test_confirm_full_coverage(self) -> None: + self.assertEqual(list(AgentType), list(AGENT_TO_CLASS.keys()))