From 1b89a438cf83b9fe1e77393956a8a977d616add4 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 26 Jan 2023 19:48:47 -0800 Subject: [PATCH] (wip) Harrison/serialize agents (#725) --- .../agents/examples/serialization.ipynb | 148 ++++++++++++++++++ langchain/agents/__init__.py | 4 +- langchain/agents/agent.py | 75 ++++++++- langchain/agents/conversational/base.py | 10 +- langchain/agents/initialize.py | 68 ++++++++ langchain/agents/loading.py | 87 +++++----- langchain/agents/mrkl/base.py | 5 + langchain/agents/react/base.py | 5 + langchain/agents/self_ask_with_search/base.py | 5 + 9 files changed, 361 insertions(+), 46 deletions(-) create mode 100644 docs/modules/agents/examples/serialization.ipynb create mode 100644 langchain/agents/initialize.py diff --git a/docs/modules/agents/examples/serialization.ipynb b/docs/modules/agents/examples/serialization.ipynb new file mode 100644 index 00000000000..5e4ad0d028c --- /dev/null +++ b/docs/modules/agents/examples/serialization.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bfe18e28", + "metadata": {}, + "source": [ + "# Serialization\n", + "\n", + "This notebook goes over how to serialize agents. For this notebook, it is important to understand the distinction we draw between `agents` and `tools`. An agent is the LLM powered decision maker that decides which actions to take and in which order. Tools are various instruments (functions) an agent has access to, through which an agent can interact with the outside world. When people generally use agents, they primarily talk about using an agent WITH tools. However, when we talk about serialization of agents, we are talking about the agent by itself. 
We plan to add support for serializing an agent WITH tools sometime in the future.\n", + "\n", + "Let's start by creating an agent with tools as we normally do:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "eb729f16", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import load_tools\n", + "from langchain.agents import initialize_agent\n", + "from langchain.llms import OpenAI\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n", + "agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "0578f566", + "metadata": {}, + "source": [ + "Let's now serialize the agent. To be explicit that we are serializing ONLY the agent, we will call the `save_agent` method." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dc544de6", + "metadata": {}, + "outputs": [], + "source": [ + "agent.save_agent('agent.json')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "62dd45bf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\r\n", + " \"llm_chain\": {\r\n", + " \"memory\": null,\r\n", + " \"verbose\": false,\r\n", + " \"prompt\": {\r\n", + " \"input_variables\": [\r\n", + " \"input\",\r\n", + " \"agent_scratchpad\"\r\n", + " ],\r\n", + " \"output_parser\": null,\r\n", + " \"template\": \"Answer the following questions as best you can. You have access to the following tools:\\n\\nSearch: A search engine. Useful for when you need to answer questions about current events. 
Input should be a search query.\\nCalculator: Useful for when you need to answer questions about math.\\n\\nUse the following format:\\n\\nQuestion: the input question you must answer\\nThought: you should always think about what to do\\nAction: the action to take, should be one of [Search, Calculator]\\nAction Input: the input to the action\\nObservation: the result of the action\\n... (this Thought/Action/Action Input/Observation can repeat N times)\\nThought: I now know the final answer\\nFinal Answer: the final answer to the original input question\\n\\nBegin!\\n\\nQuestion: {input}\\nThought:{agent_scratchpad}\",\r\n", + " \"template_format\": \"f-string\"\r\n", + " },\r\n", + " \"llm\": {\r\n", + " \"model_name\": \"text-davinci-003\",\r\n", + " \"temperature\": 0.0,\r\n", + " \"max_tokens\": 256,\r\n", + " \"top_p\": 1,\r\n", + " \"frequency_penalty\": 0,\r\n", + " \"presence_penalty\": 0,\r\n", + " \"n\": 1,\r\n", + " \"best_of\": 1,\r\n", + " \"request_timeout\": null,\r\n", + " \"logit_bias\": {},\r\n", + " \"_type\": \"openai\"\r\n", + " },\r\n", + " \"output_key\": \"text\",\r\n", + " \"_type\": \"llm_chain\"\r\n", + " },\r\n", + " \"return_values\": [\r\n", + " \"output\"\r\n", + " ],\r\n", + " \"_type\": \"zero-shot-react-description\"\r\n", + "}" + ] + } + ], + "source": [ + "!cat agent.json" + ] + }, + { + "cell_type": "markdown", + "id": "0eb72510", + "metadata": {}, + "source": [ + "We can now load the agent back in" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "eb660b76", + "metadata": {}, + "outputs": [], + "source": [ + "agent = initialize_agent(tools, llm, agent_path=\"agent.json\", verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa624ea5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": 
"ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index 1e8ad771e96..b25d17a13c8 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -1,8 +1,9 @@ """Interface for agents.""" from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.conversational.base import ConversationalAgent +from langchain.agents.initialize import initialize_agent from langchain.agents.load_tools import get_all_tool_names, load_tools -from langchain.agents.loading import initialize_agent +from langchain.agents.loading import load_agent from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent from langchain.agents.react.base import ReActChain, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain @@ -21,4 +22,5 @@ __all__ = [ "load_tools", "get_all_tool_names", "ConversationalAgent", + "load_agent", ] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 22e7667b6ba..61645219f21 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -1,10 +1,13 @@ """Chain that takes in an input and produces an action and action input.""" from __future__ import annotations +import json import logging from abc import abstractmethod +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union +import yaml from pydantic import BaseModel, root_validator from langchain.agents.tools import Tool @@ -30,6 +33,7 @@ class Agent(BaseModel): """ llm_chain: LLMChain + allowed_tools: List[str] return_values: List[str] = ["output"] @abstractmethod @@ -146,7 +150,8 @@ class Agent(BaseModel): prompt=cls.create_prompt(tools), callback_manager=callback_manager, ) - return cls(llm_chain=llm_chain, **kwargs) 
+ tool_names = [tool.name for tool in tools] + return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) def return_stopped_response( self, @@ -192,6 +197,50 @@ class Agent(BaseModel): f"got {early_stopping_method}" ) + @property + @abstractmethod + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + + def dict(self, **kwargs: Any) -> Dict: + """Return dictionary representation of agent, tagged with its `_type`.""" + _dict = super().dict(**kwargs) + _dict["_type"] = self._agent_type + return _dict + + def save(self, file_path: Union[Path, str]) -> None: + """Save the agent. + + Args: + file_path: Path to file to save the agent to. + + Example: + .. code-block:: python + + # If working with agent executor + agent.agent.save(file_path="path/agent.yaml") + """ + # Convert file to Path object. + if isinstance(file_path, str): + save_path = Path(file_path) + else: + save_path = file_path + + directory_path = save_path.parent + directory_path.mkdir(parents=True, exist_ok=True) + + # Fetch dictionary to save + agent_dict = self.dict() + + if save_path.suffix == ".json": + with open(save_path, "w") as f: + json.dump(agent_dict, f, indent=4) + elif save_path.suffix == ".yaml": + with open(save_path, "w") as f: + yaml.dump(agent_dict, f, default_flow_style=False) + else: + raise ValueError(f"{save_path} must be json or yaml") + class AgentExecutor(Chain, BaseModel): """Consists of an agent using tools.""" @@ -215,6 +264,30 @@ class AgentExecutor(Chain, BaseModel): agent=agent, tools=tools, callback_manager=callback_manager, **kwargs ) + @root_validator() + def validate_tools(cls, values: Dict) -> Dict: + """Validate that tools are compatible with agent.""" + agent = values["agent"] + tools = values["tools"] + if set(agent.allowed_tools) != {tool.name for tool in tools}: + raise ValueError( + f"Allowed tools ({agent.allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) + return values + + def save(self, file_path: Union[Path, str]) -> 
None: + """Raise error - saving not supported for Agent Executors.""" + raise ValueError( + "Saving not supported for agent executors. " + "If you are trying to save the agent, please use the " + "`.save_agent(...)`" + ) + + def save_agent(self, file_path: Union[Path, str]) -> None: + """Save the underlying agent.""" + return self.agent.save(file_path) + @property def input_keys(self) -> List[str]: """Return the input keys. diff --git a/langchain/agents/conversational/base.py b/langchain/agents/conversational/base.py index 7ea6a268788..1cf24fa4649 100644 --- a/langchain/agents/conversational/base.py +++ b/langchain/agents/conversational/base.py @@ -18,6 +18,11 @@ class ConversationalAgent(Agent): ai_prefix: str = "AI" + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + return "conversational-react-description" + @property def observation_prefix(self) -> str: """Prefix to append the observation with.""" @@ -100,4 +105,7 @@ class ConversationalAgent(Agent): prompt=prompt, callback_manager=callback_manager, ) - return cls(llm_chain=llm_chain, ai_prefix=ai_prefix, **kwargs) + tool_names = [tool.name for tool in tools] + return cls( + llm_chain=llm_chain, allowed_tools=tool_names, ai_prefix=ai_prefix, **kwargs + ) diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py new file mode 100644 index 00000000000..17d1cf53232 --- /dev/null +++ b/langchain/agents/initialize.py @@ -0,0 +1,68 @@ +"""Load agent.""" +from typing import Any, List, Optional + +from langchain.agents.agent import AgentExecutor +from langchain.agents.loading import AGENT_TO_CLASS, load_agent +from langchain.agents.tools import Tool +from langchain.callbacks.base import BaseCallbackManager +from langchain.llms.base import BaseLLM + + +def initialize_agent( + tools: List[Tool], + llm: BaseLLM, + agent: Optional[str] = None, + callback_manager: Optional[BaseCallbackManager] = None, + agent_path: Optional[str] = None, + **kwargs: Any, +) -> 
AgentExecutor: + """Load agent given tools and LLM. + + Args: + tools: List of tools this agent has access to. + llm: Language model to use as the agent. + agent: The agent to use. Valid options are: + `zero-shot-react-description` + `react-docstore` + `self-ask-with-search` + `conversational-react-description` + If None and agent_path is also None, will default to + `zero-shot-react-description`. + callback_manager: CallbackManager to use. Global callback manager is used if + not provided. Defaults to None. + agent_path: Path to serialized agent to use. + **kwargs: Additional key word arguments to pass to the agent executor. + + Returns: + An agent executor. + """ + if agent is None and agent_path is None: + agent = "zero-shot-react-description" + if agent is not None and agent_path is not None: + raise ValueError( + "Both `agent` and `agent_path` are specified, " + "but at most only one should be." + ) + if agent is not None: + if agent not in AGENT_TO_CLASS: + raise ValueError( + f"Got unknown agent type: {agent}. " + f"Valid types are: {AGENT_TO_CLASS.keys()}." + ) + agent_cls = AGENT_TO_CLASS[agent] + agent_obj = agent_cls.from_llm_and_tools( + llm, tools, callback_manager=callback_manager + ) + elif agent_path is not None: + agent_obj = load_agent(agent_path, callback_manager=callback_manager) + else: + raise ValueError( + "Somehow both `agent` and `agent_path` are None, " + "this should never happen." 
+ ) + return AgentExecutor.from_agent_and_tools( + agent=agent_obj, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index bb76fb7ff9a..cb00d795801 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,14 +1,16 @@ -"""Load agent.""" -from typing import Any, List, Optional +"""Functionality for loading agents.""" +import json +from pathlib import Path +from typing import Any, Union -from langchain.agents.agent import AgentExecutor +import yaml + +from langchain.agents.agent import Agent from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent -from langchain.agents.tools import Tool -from langchain.callbacks.base import BaseCallbackManager -from langchain.llms.base import BaseLLM +from langchain.chains.loading import load_chain, load_chain_from_config AGENT_TO_CLASS = { "zero-shot-react-description": ZeroShotAgent, @@ -18,42 +20,41 @@ AGENT_TO_CLASS = { } -def initialize_agent( - tools: List[Tool], - llm: BaseLLM, - agent: str = "zero-shot-react-description", - callback_manager: Optional[BaseCallbackManager] = None, - **kwargs: Any, -) -> AgentExecutor: - """Load agent given tools and LLM. +def load_agent_from_config(config: dict, **kwargs: Any) -> Agent: + """Load agent from Config Dict.""" + if "_type" not in config: + raise ValueError("Must specify an agent Type in config") + config_type = config.pop("_type") - Args: - tools: List of tools this agent has access to. - llm: Language model to use as the agent. - agent: The agent to use. Valid options are: - `zero-shot-react-description` - `react-docstore` - `self-ask-with-search` - `conversational-react-description`. - callback_manager: CallbackManager to use. 
Global callback manager is used if - not provided. Defaults to None. - **kwargs: Additional key word arguments to pass to the agent. + if config_type not in AGENT_TO_CLASS: + raise ValueError(f"Loading {config_type} agent not supported") - Returns: - An agent. - """ - if agent not in AGENT_TO_CLASS: - raise ValueError( - f"Got unknown agent type: {agent}. " - f"Valid types are: {AGENT_TO_CLASS.keys()}." - ) - agent_cls = AGENT_TO_CLASS[agent] - agent_obj = agent_cls.from_llm_and_tools( - llm, tools, callback_manager=callback_manager - ) - return AgentExecutor.from_agent_and_tools( - agent=agent_obj, - tools=tools, - callback_manager=callback_manager, - **kwargs, - ) + agent_cls = AGENT_TO_CLASS[config_type] + if "llm_chain" in config: + config["llm_chain"] = load_chain_from_config(config.pop("llm_chain")) + elif "llm_chain_path" in config: + config["llm_chain"] = load_chain(config.pop("llm_chain_path")) + else: + raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.") + combined_config = {**config, **kwargs} + return agent_cls(**combined_config) # type: ignore + + +def load_agent(file: Union[str, Path], **kwargs: Any) -> Agent: + """Load agent from file.""" + # Convert file to Path object. + if isinstance(file, str): + file_path = Path(file) + else: + file_path = file + # Load from either json or yaml. + if file_path.suffix == ".json": + with open(file_path) as f: + config = json.load(f) + elif file_path.suffix == ".yaml": + with open(file_path, "r") as f: + config = yaml.safe_load(f) + else: + raise ValueError("File type must be json or yaml") + # Load the agent from the config now. 
+ return load_agent_from_config(config, **kwargs) diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index a885f54d7bf..037367ed32b 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -49,6 +49,11 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]: class ZeroShotAgent(Agent): """Agent for the MRKL chain.""" + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + return "zero-shot-react-description" + @property def observation_prefix(self) -> str: """Prefix to append the observation with.""" diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index b6895205aba..e68dd45d6b4 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -17,6 +17,11 @@ from langchain.prompts.base import BasePromptTemplate class ReActDocstoreAgent(Agent, BaseModel): """Agent for the ReAct chain.""" + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + return "react-docstore" + @classmethod def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: """Return default prompt.""" diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index 4432c76fb14..e542d125a1d 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -12,6 +12,11 @@ from langchain.serpapi import SerpAPIWrapper class SelfAskWithSearchAgent(Agent): """Agent for the self-ask-with-search paper.""" + @property + def _agent_type(self) -> str: + """Return Identifier of agent type.""" + return "self-ask-with-search" + @classmethod def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate: """Prompt does not depend on tools."""