mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 12:39:32 +00:00
(wip) Harrison/serialize agents (#725)
This commit is contained in:
148
docs/modules/agents/examples/serialization.ipynb
Normal file
148
docs/modules/agents/examples/serialization.ipynb
Normal file
@@ -0,0 +1,148 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bfe18e28",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Serialization\n",
|
||||
"\n",
|
||||
"This notebook goes over how to serialize agents. For this notebook, it is important to understand the distinction we draw between `agents` and `tools`. An agent is the LLM powered decision maker that decides which actions to take and in which order. Tools are various instruments (functions) an agent has access to, through which an agent can interact with the outside world. When people generally use agents, they primarily talk about using an agent WITH tools. However, when we talk about serialization of agents, we are talking about the agent by itself. We plan to add support for serializing an agent WITH tools sometime in the future.\n",
|
||||
"\n",
|
||||
"Let's start by creating an agent with tools as we normally do:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "eb729f16",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.agents import load_tools\n",
|
||||
"from langchain.agents import initialize_agent\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
|
||||
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0578f566",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's now serialize the agent. To be explicit that we are serializing ONLY the agent, we will call the `save_agent` method."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "dc544de6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent.save_agent('agent.json')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "62dd45bf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\r\n",
|
||||
" \"llm_chain\": {\r\n",
|
||||
" \"memory\": null,\r\n",
|
||||
" \"verbose\": false,\r\n",
|
||||
" \"prompt\": {\r\n",
|
||||
" \"input_variables\": [\r\n",
|
||||
" \"input\",\r\n",
|
||||
" \"agent_scratchpad\"\r\n",
|
||||
" ],\r\n",
|
||||
" \"output_parser\": null,\r\n",
|
||||
" \"template\": \"Answer the following questions as best you can. You have access to the following tools:\\n\\nSearch: A search engine. Useful for when you need to answer questions about current events. Input should be a search query.\\nCalculator: Useful for when you need to answer questions about math.\\n\\nUse the following format:\\n\\nQuestion: the input question you must answer\\nThought: you should always think about what to do\\nAction: the action to take, should be one of [Search, Calculator]\\nAction Input: the input to the action\\nObservation: the result of the action\\n... (this Thought/Action/Action Input/Observation can repeat N times)\\nThought: I now know the final answer\\nFinal Answer: the final answer to the original input question\\n\\nBegin!\\n\\nQuestion: {input}\\nThought:{agent_scratchpad}\",\r\n",
|
||||
" \"template_format\": \"f-string\"\r\n",
|
||||
" },\r\n",
|
||||
" \"llm\": {\r\n",
|
||||
" \"model_name\": \"text-davinci-003\",\r\n",
|
||||
" \"temperature\": 0.0,\r\n",
|
||||
" \"max_tokens\": 256,\r\n",
|
||||
" \"top_p\": 1,\r\n",
|
||||
" \"frequency_penalty\": 0,\r\n",
|
||||
" \"presence_penalty\": 0,\r\n",
|
||||
" \"n\": 1,\r\n",
|
||||
" \"best_of\": 1,\r\n",
|
||||
" \"request_timeout\": null,\r\n",
|
||||
" \"logit_bias\": {},\r\n",
|
||||
" \"_type\": \"openai\"\r\n",
|
||||
" },\r\n",
|
||||
" \"output_key\": \"text\",\r\n",
|
||||
" \"_type\": \"llm_chain\"\r\n",
|
||||
" },\r\n",
|
||||
" \"return_values\": [\r\n",
|
||||
" \"output\"\r\n",
|
||||
" ],\r\n",
|
||||
" \"_type\": \"zero-shot-react-description\"\r\n",
|
||||
"}"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!cat agent.json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0eb72510",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can now load the agent back in"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "eb660b76",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent = initialize_agent(tools, llm, agent_path=\"agent.json\", verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aa624ea5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@@ -1,8 +1,9 @@
|
||||
"""Interface for agents."""
|
||||
from langchain.agents.agent import Agent, AgentExecutor
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
from langchain.agents.initialize import initialize_agent
|
||||
from langchain.agents.load_tools import get_all_tool_names, load_tools
|
||||
from langchain.agents.loading import initialize_agent
|
||||
from langchain.agents.loading import load_agent
|
||||
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
@@ -21,4 +22,5 @@ __all__ = [
|
||||
"load_tools",
|
||||
"get_all_tool_names",
|
||||
"ConversationalAgent",
|
||||
"load_agent",
|
||||
]
|
||||
|
@@ -1,10 +1,13 @@
|
||||
"""Chain that takes in an input and produces an action and action input."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.agents.tools import Tool
|
||||
@@ -30,6 +33,7 @@ class Agent(BaseModel):
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
allowed_tools: List[str]
|
||||
return_values: List[str] = ["output"]
|
||||
|
||||
@abstractmethod
|
||||
@@ -146,7 +150,8 @@ class Agent(BaseModel):
|
||||
prompt=cls.create_prompt(tools),
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
|
||||
def return_stopped_response(
|
||||
self,
|
||||
@@ -192,6 +197,50 @@ class Agent(BaseModel):
|
||||
f"got {early_stopping_method}"
|
||||
)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
_dict["_type"] = self._agent_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the agent.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the agent to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# If working with agent executor
|
||||
agent.agent.save(file_path="path/agent.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
agent_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(agent_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(agent_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class AgentExecutor(Chain, BaseModel):
|
||||
"""Consists of an agent using tools."""
|
||||
@@ -215,6 +264,30 @@ class AgentExecutor(Chain, BaseModel):
|
||||
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
def validate_tools(cls, values: Dict) -> Dict:
|
||||
"""Validate that tools are compatible with agent."""
|
||||
agent = values["agent"]
|
||||
tools = values["tools"]
|
||||
if set(agent.allowed_tools) != set([tool.name for tool in tools]):
|
||||
raise ValueError(
|
||||
f"Allowed tools ({agent.allowed_tools}) different than "
|
||||
f"provided tools ({[tool.name for tool in tools]})"
|
||||
)
|
||||
return values
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Raise error - saving not supported for Agent Executors."""
|
||||
raise ValueError(
|
||||
"Saving not supported for agent executors. "
|
||||
"If you are trying to save the agent, please use the "
|
||||
"`.save_agent(...)`"
|
||||
)
|
||||
|
||||
def save_agent(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the underlying agent."""
|
||||
return self.agent.save(file_path)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
@@ -18,6 +18,11 @@ class ConversationalAgent(Agent):
|
||||
|
||||
ai_prefix: str = "AI"
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
return "conversational-react-description"
|
||||
|
||||
@property
|
||||
def observation_prefix(self) -> str:
|
||||
"""Prefix to append the observation with."""
|
||||
@@ -100,4 +105,7 @@ class ConversationalAgent(Agent):
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
return cls(llm_chain=llm_chain, ai_prefix=ai_prefix, **kwargs)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
return cls(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, ai_prefix=ai_prefix, **kwargs
|
||||
)
|
||||
|
68
langchain/agents/initialize.py
Normal file
68
langchain/agents/initialize.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Load agent."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
def initialize_agent(
|
||||
tools: List[Tool],
|
||||
llm: BaseLLM,
|
||||
agent: Optional[str] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
agent_path: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Load agent given tools and LLM.
|
||||
|
||||
Args:
|
||||
tools: List of tools this agent has access to.
|
||||
llm: Language model to use as the agent.
|
||||
agent: The agent to use. Valid options are:
|
||||
`zero-shot-react-description`
|
||||
`react-docstore`
|
||||
`self-ask-with-search`
|
||||
`conversational-react-description`
|
||||
If None and agent_path is also None, will default to
|
||||
`zero-shot-react-description`.
|
||||
callback_manager: CallbackManager to use. Global callback manager is used if
|
||||
not provided. Defaults to None.
|
||||
agent_path: Path to serialized agent to use.
|
||||
**kwargs: Additional key word arguments to pass to the agent.
|
||||
|
||||
Returns:
|
||||
An agent.
|
||||
"""
|
||||
if agent is None and agent_path is None:
|
||||
agent = "zero-shot-react-description"
|
||||
if agent is not None and agent_path is not None:
|
||||
raise ValueError(
|
||||
"Both `agent` and `agent_path` are specified, "
|
||||
"but at most only one should be."
|
||||
)
|
||||
if agent is not None:
|
||||
if agent not in AGENT_TO_CLASS:
|
||||
raise ValueError(
|
||||
f"Got unknown agent type: {agent}. "
|
||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||
)
|
||||
agent_cls = AGENT_TO_CLASS[agent]
|
||||
agent_obj = agent_cls.from_llm_and_tools(
|
||||
llm, tools, callback_manager=callback_manager
|
||||
)
|
||||
elif agent_path is not None:
|
||||
agent_obj = load_agent(agent_path, callback_manager=callback_manager)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Somehow both `agent` and `agent_path` are None, "
|
||||
"this should never happen."
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
@@ -1,14 +1,16 @@
|
||||
"""Load agent."""
|
||||
from typing import Any, List, Optional
|
||||
"""Functionality for loading agents."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
import yaml
|
||||
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.chains.loading import load_chain, load_chain_from_config
|
||||
|
||||
AGENT_TO_CLASS = {
|
||||
"zero-shot-react-description": ZeroShotAgent,
|
||||
@@ -18,42 +20,41 @@ AGENT_TO_CLASS = {
|
||||
}
|
||||
|
||||
|
||||
def initialize_agent(
|
||||
tools: List[Tool],
|
||||
llm: BaseLLM,
|
||||
agent: str = "zero-shot-react-description",
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Load agent given tools and LLM.
|
||||
def load_agent_from_config(config: dict, **kwargs: Any) -> Agent:
|
||||
"""Load agent from Config Dict."""
|
||||
if "_type" not in config:
|
||||
raise ValueError("Must specify an agent Type in config")
|
||||
config_type = config.pop("_type")
|
||||
|
||||
Args:
|
||||
tools: List of tools this agent has access to.
|
||||
llm: Language model to use as the agent.
|
||||
agent: The agent to use. Valid options are:
|
||||
`zero-shot-react-description`
|
||||
`react-docstore`
|
||||
`self-ask-with-search`
|
||||
`conversational-react-description`.
|
||||
callback_manager: CallbackManager to use. Global callback manager is used if
|
||||
not provided. Defaults to None.
|
||||
**kwargs: Additional key word arguments to pass to the agent.
|
||||
if config_type not in AGENT_TO_CLASS:
|
||||
raise ValueError(f"Loading {config_type} agent not supported")
|
||||
|
||||
Returns:
|
||||
An agent.
|
||||
"""
|
||||
if agent not in AGENT_TO_CLASS:
|
||||
raise ValueError(
|
||||
f"Got unknown agent type: {agent}. "
|
||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||
)
|
||||
agent_cls = AGENT_TO_CLASS[agent]
|
||||
agent_obj = agent_cls.from_llm_and_tools(
|
||||
llm, tools, callback_manager=callback_manager
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
agent_cls = AGENT_TO_CLASS[config_type]
|
||||
if "llm_chain" in config:
|
||||
config["llm_chain"] = load_chain_from_config(config.pop("llm_chain"))
|
||||
elif "llm_chain_path" in config:
|
||||
config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
|
||||
else:
|
||||
raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
|
||||
combined_config = {**config, **kwargs}
|
||||
return agent_cls(**combined_config) # type: ignore
|
||||
|
||||
|
||||
def load_agent(file: Union[str, Path], **kwargs: Any) -> Agent:
|
||||
"""Load agent from file."""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
# Load from either json or yaml.
|
||||
if file_path.suffix == ".json":
|
||||
with open(file_path) as f:
|
||||
config = json.load(f)
|
||||
elif file_path.suffix == ".yaml":
|
||||
with open(file_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError("File type must be json or yaml")
|
||||
# Load the agent from the config now.
|
||||
return load_agent_from_config(config, **kwargs)
|
||||
|
@@ -49,6 +49,11 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
|
||||
class ZeroShotAgent(Agent):
|
||||
"""Agent for the MRKL chain."""
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
return "zero-shot-react-description"
|
||||
|
||||
@property
|
||||
def observation_prefix(self) -> str:
|
||||
"""Prefix to append the observation with."""
|
||||
|
@@ -17,6 +17,11 @@ from langchain.prompts.base import BasePromptTemplate
|
||||
class ReActDocstoreAgent(Agent, BaseModel):
|
||||
"""Agent for the ReAct chain."""
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
return "react-docstore"
|
||||
|
||||
@classmethod
|
||||
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
|
||||
"""Return default prompt."""
|
||||
|
@@ -12,6 +12,11 @@ from langchain.serpapi import SerpAPIWrapper
|
||||
class SelfAskWithSearchAgent(Agent):
|
||||
"""Agent for the self-ask-with-search paper."""
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
return "self-ask-with-search"
|
||||
|
||||
@classmethod
|
||||
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
|
||||
"""Prompt does not depend on tools."""
|
||||
|
Reference in New Issue
Block a user