From 512c24fc9ceb6e6e464eac7f21dc07c4d5263024 Mon Sep 17 00:00:00 2001 From: Mike Wang <62768671+skcoirz@users.noreply.github.com> Date: Fri, 28 Apr 2023 21:17:28 -0700 Subject: [PATCH] [annotation improvement] Make AgentType->Class Conversion More Scalable (#3749) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the current solution, AgentType and AGENT_TO_CLASS are placed in two separate files and both manually maintained. This might cause inconsistency when we update either of them. — latest — based on the discussion with hwchase17, we don’t know how to further use the newly introduced AgentTypeConfig type, so it doesn’t make sense yet to add it. Instead, it’s better to move the dictionary to another file to keep the loading.py file clear. The consistency is a good point. Instead of asserting the consistency during linting, we added a unittest for consistency check. I think it works as auto unittest is triggered every time with clear failure notice. (well, force push is possible, but we all know what we are doing, so let’s show trust. :>) ~~This PR includes~~ - ~~Introduced AgentTypeConfig as the source of truth of all AgentType related meta data.~~ - ~~Each AgentTypeConfig is a annotated class type which can be used for annotation in other places.~~ - ~~Each AgentTypeConfig can be easily extended when we have more meta data needs.~~ - ~~Strong assertion to ensure AgentType and AGENT_TO_CLASS are always consistent.~~ - ~~Made AGENT_TO_CLASS automatically generated.~~ ~~Test Plan:~~ - ~~since this change is focusing on annotation, lint is the major test focus.~~ - ~~lint, format and test passed on local.~~ --- langchain/agents/loading.py | 19 ++----------------- langchain/agents/types.py | 19 +++++++++++++++++++ tests/unit_tests/agents/test_types.py | 9 +++++++++ 3 files changed, 30 insertions(+), 17 deletions(-) create mode 100644 langchain/agents/types.py create mode 100644 tests/unit_tests/agents/test_types.py diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index e071b783212..50d24d7d82e 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,32 +1,17 @@ """Functionality for loading agents.""" import json from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, List, Optional, Union import yaml from langchain.agents.agent import BaseSingleActionAgent -from langchain.agents.agent_types import AgentType -from langchain.agents.chat.base import ChatAgent -from langchain.agents.conversational.base import ConversationalAgent -from langchain.agents.conversational_chat.base import ConversationalChatAgent -from langchain.agents.mrkl.base import ZeroShotAgent -from langchain.agents.react.base import ReActDocstoreAgent -from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool +from langchain.agents.types import AGENT_TO_CLASS from langchain.chains.loading import load_chain, load_chain_from_config from langchain.llms.base import BaseLLM from langchain.utilities.loading import try_load_from_hub -AGENT_TO_CLASS: Dict[AgentType, Type[BaseSingleActionAgent]] = { - AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent, - AgentType.REACT_DOCSTORE: ReActDocstoreAgent, - AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent, - AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent, - AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent, - AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent, -} - URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" diff --git a/langchain/agents/types.py b/langchain/agents/types.py new file mode 100644 index 00000000000..c3d8b2a57fc --- /dev/null +++ b/langchain/agents/types.py @@ -0,0 +1,19 @@ +from typing import Dict, Type + +from langchain.agents.agent import BaseSingleActionAgent +from langchain.agents.agent_types import AgentType +from langchain.agents.chat.base import ChatAgent +from langchain.agents.conversational.base import ConversationalAgent +from langchain.agents.conversational_chat.base import ConversationalChatAgent +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.react.base import ReActDocstoreAgent +from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent + +AGENT_TO_CLASS: Dict[AgentType, Type[BaseSingleActionAgent]] = { + AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent, + AgentType.REACT_DOCSTORE: ReActDocstoreAgent, + AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent, + AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent, + AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent, + AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent, +} diff --git a/tests/unit_tests/agents/test_types.py b/tests/unit_tests/agents/test_types.py new file mode 100644 index 00000000000..536d1f1d177 --- /dev/null +++ b/tests/unit_tests/agents/test_types.py @@ -0,0 +1,9 @@ +import unittest + +from langchain.agents.agent_types import AgentType +from langchain.agents.types import AGENT_TO_CLASS + + +class TestTypes(unittest.TestCase): + def test_confirm_full_coverage(self) -> None: + self.assertEqual(list(AgentType), list(AGENT_TO_CLASS.keys()))