From a5d003f0c99d5c6d6f2f80a174c45207663b282a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 28 Jan 2023 07:23:04 -0800 Subject: [PATCH] update notebook and make backwards compatible (#772) --- .../agents/examples/custom_agent.ipynb | 35 ++++++++++--------- langchain/agents/agent.py | 13 +++---- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/docs/modules/agents/examples/custom_agent.ipynb b/docs/modules/agents/examples/custom_agent.ipynb index 780f2b28055..579492caec5 100644 --- a/docs/modules/agents/examples/custom_agent.ipynb +++ b/docs/modules/agents/examples/custom_agent.ipynb @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "becda2a1", "metadata": {}, "outputs": [], @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "id": "339b1bb8", "metadata": {}, "outputs": [], @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "id": "e21d2098", "metadata": {}, "outputs": [ @@ -134,7 +134,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "5e028e6d", "metadata": {}, @@ -146,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "id": "9b1cc2a2", "metadata": {}, "outputs": [], @@ -156,17 +155,18 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "id": "e4f5092f", "metadata": {}, "outputs": [], "source": [ - "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)" + "tool_names = [tool.name for tool in tools]\n", + "agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 8, "id": "490604e9", "metadata": {}, "outputs": [], @@ -176,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "id": "653b1617", "metadata": {}, "outputs": [ @@ -191,22 +191,23 @@ "Action: Search\n", "Action Input: Population of Canada\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mCanada is a country in North America. Its ten provinces and three territories extend from the Atlantic Ocean to the Pacific Ocean and northward into the Arctic Ocean, covering over 9.98 million square kilometres, making it the world's second-largest country by total area.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out the exact population of Canada\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the population of Canada\n", "Action: Search\n", - "Action Input: Population of Canada 2020\u001b[0m\n", + "Action Input: Population of Canada\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mCanada is a country in North America. Its ten provinces and three territories extend from the Atlantic Ocean to the Pacific Ocean and northward into the Arctic Ocean, covering over 9.98 million square kilometres, making it the world's second-largest country by total area.\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the population of Canada\n", - "Final Answer: Arrr, Canada be home to 37.59 million people!\u001b[0m\n", - "\u001b[1m> Finished AgentExecutor chain.\u001b[0m\n" + "Final Answer: Arrr, Canada be home to over 37 million people!\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'Arrr, Canada be home to 37.59 million people!'" + "'Arrr, Canada be home to over 37 million people!'" ] }, - "execution_count": 19, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -361,7 +362,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -375,7 +376,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12 (default, Feb 15 2022, 17:41:09) \n[Clang 12.0.5 (clang-1205.0.22.11)]" + "version": "3.10.9" }, "vscode": { "interpreter": { diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 61645219f21..f520c562486 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -33,7 +33,7 @@ class Agent(BaseModel): """ llm_chain: LLMChain - allowed_tools: List[str] + allowed_tools: Optional[List[str]] = None return_values: List[str] = ["output"] @abstractmethod @@ -269,11 +269,12 @@ class AgentExecutor(Chain, BaseModel): """Validate that tools are compatible with agent.""" agent = values["agent"] tools = values["tools"] - if set(agent.allowed_tools) != set([tool.name for tool in tools]): - raise ValueError( - f"Allowed tools ({agent.allowed_tools}) different than " - f"provided tools ({[tool.name for tool in tools]})" - ) + if agent.allowed_tools is not None: + if set(agent.allowed_tools) != set([tool.name for tool in tools]): + raise ValueError( + f"Allowed tools ({agent.allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) return values def save(self, file_path: Union[Path, str]) -> None: