Harrison/agent docs custom (#14877)

Harrison Chase 2023-12-18 17:49:32 -08:00 committed by GitHub
parent 6fbd068b3f
commit f8dccaa027
2 changed files with 321 additions and 93 deletions


@ -9,119 +9,156 @@
"\n",
"This notebook goes through how to create your own custom agent.\n",
"\n",
"An agent consists of two parts:\n",
" \n",
" - Tools: The tools the agent has available to use.\n",
" - The agent class itself: this decides which action to take.\n",
" \n",
" \n",
"In this notebook we walk through how to create a custom agent."
"In this example, we will use OpenAI Function Calling to create this agent.\n",
"**This is generally the most reliable way to create agents.**\n",
"\n",
"We will first create it WITHOUT memory, but we will then show how to add memory in.\n",
"Memory is needed to enable conversation.\n",
"\n",
"First, let's load the language model we're going to use to control the agent."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9af9734e",
"id": "c712d4a9",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor, BaseSingleActionAgent, Tool\n",
"from langchain.utilities import SerpAPIWrapper"
"from langchain.chat_models import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "becda2a1",
"cell_type": "markdown",
"id": "c7121568",
"metadata": {},
"outputs": [],
"source": [
"search = SerpAPIWrapper()\n",
"tools = [\n",
" Tool(\n",
" name=\"Search\",\n",
" func=search.run,\n",
" description=\"useful for when you need to answer questions about current events\",\n",
" return_direct=True,\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a33e2f7e",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, List, Tuple, Union\n",
"\n",
"from langchain.schema import AgentAction, AgentFinish\n",
"\n",
"\n",
"class FakeAgent(BaseSingleActionAgent):\n",
" \"\"\"Fake Custom Agent.\"\"\"\n",
"\n",
" @property\n",
" def input_keys(self):\n",
" return [\"input\"]\n",
"\n",
" def plan(\n",
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
" ) -> Union[AgentAction, AgentFinish]:\n",
" \"\"\"Given input, decided what to do.\n",
"\n",
" Args:\n",
" intermediate_steps: Steps the LLM has taken to date,\n",
" along with observations\n",
" **kwargs: User inputs.\n",
"\n",
" Returns:\n",
" Action specifying what tool to use.\n",
" \"\"\"\n",
" return AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\")\n",
"\n",
" async def aplan(\n",
" self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any\n",
" ) -> Union[AgentAction, AgentFinish]:\n",
" \"\"\"Given input, decided what to do.\n",
"\n",
" Args:\n",
" intermediate_steps: Steps the LLM has taken to date,\n",
" along with observations\n",
" **kwargs: User inputs.\n",
"\n",
" Returns:\n",
" Action specifying what tool to use.\n",
" \"\"\"\n",
" return AgentAction(tool=\"Search\", tool_input=kwargs[\"input\"], log=\"\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "655d72f6",
"metadata": {},
"outputs": [],
"source": [
"agent = FakeAgent()"
"Next, let's define some tools to use.\n",
"Let's write a really simple Python function to calculate the length of a word that is passed in."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "490604e9",
"id": "fbe32b5f",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = AgentExecutor.from_agent_and_tools(\n",
" agent=agent, tools=tools, verbose=True\n",
"from langchain.agents import tool\n",
"\n",
"\n",
"@tool\n",
"def get_word_length(word: str) -> int:\n",
" \"\"\"Returns the length of a word.\"\"\"\n",
" return len(word)\n",
"\n",
"\n",
"tools = [get_word_length]"
]
},
{
"cell_type": "markdown",
"id": "ae021421",
"metadata": {},
"source": [
"Now let us create the prompt.\n",
"Because OpenAI function-calling models are fine-tuned for tool usage, we hardly need any instructions on how to reason or how to format the output.\n",
"We will just have two input variables: `input` and `agent_scratchpad`. `input` should be a string containing the user objective. `agent_scratchpad` should be a sequence of messages that contains the previous agent tool invocations and the corresponding tool outputs."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "aa4b50ea",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\n",
"            \"system\",\n",
"            \"You are a very powerful assistant, but you don't know current events\",\n",
"        ),\n",
"        (\"user\", \"{input}\"),\n",
"        MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
"    ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a7bc8eea",
"metadata": {},
"source": [
"How does the agent know what tools it can use?\n",
"In this case we're relying on OpenAI function-calling LLMs, which take functions as a separate argument and have been specifically trained to know when to invoke those functions.\n",
"\n",
"To pass our tools to the agent, we just need to convert them to the OpenAI function format and pass them to our model. (By `bind`-ing the functions, we make sure that they're passed in each time the model is invoked.)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e82713b6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.tools.render import format_tool_to_openai_function\n",
"\n",
"llm_with_tools = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])"
]
},
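To make the "OpenAI function format" concrete, here is a minimal sketch in plain Python of the kind of description that gets bound to the model. The helper `sketch_function_schema` and the `_JSON_TYPES` table are hypothetical and written only for illustration; they are not part of LangChain, and the exact dictionary produced by `format_tool_to_openai_function` may differ in detail.

```python
import inspect
import json


def get_word_length(word: str) -> int:
    """Returns the length of a word."""
    return len(word)


# Hypothetical mapping from Python annotations to JSON Schema type names.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def sketch_function_schema(func):
    """Build an OpenAI-style function description from a plain function.

    Illustrative only: unknown annotations fall back to "string".
    """
    properties = {}
    required = []
    for name, param in inspect.signature(func).parameters.items():
        properties[name] = {"type": _JSON_TYPES.get(param.annotation, "string")}
        # Parameters without a default value are treated as required.
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "name": func.__name__,
        "description": inspect.getdoc(func) or "",
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": required,
        },
    }


schema = sketch_function_schema(get_word_length)
print(json.dumps(schema, indent=2))
```

The model never sees the Python function itself, only a name, a description, and a JSON Schema for the arguments, which is why the docstring on the tool matters.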
{
"cell_type": "markdown",
"id": "4565b5f2",
"metadata": {},
"source": [
"Putting those pieces together, we can now create the agent.\n",
"We will import two final utility functions: a component that formats intermediate steps (pairs of agent action and tool output) into input messages that can be sent to the model, and a component that converts the output message into an agent action or agent finish."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "925a8ca4",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
"from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser\n",
"\n",
"agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",
" | prompt\n",
" | llm_with_tools\n",
" | OpenAIFunctionsAgentOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9af9734e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor\n",
"\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
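Conceptually, the executor repeatedly asks the agent for a decision, runs the chosen tool, and feeds the observation back in until the agent finishes. The following is a rough sketch of that loop in plain Python, not the actual `AgentExecutor` implementation. The names `Action`, `Finish`, `fake_plan`, and `run_agent` are hypothetical stand-ins, and `fake_plan` replaces the prompt, model, and output parser so that no LLM is called.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple, Union


@dataclass
class Action:
    """Hypothetical stand-in for an agent action (a tool call request)."""
    tool: str
    tool_input: dict


@dataclass
class Finish:
    """Hypothetical stand-in for an agent finish (the final answer)."""
    output: str


def get_word_length(word: str) -> int:
    return len(word)


def fake_plan(user_input: str, steps: List[Tuple[Action, object]]) -> Union[Action, Finish]:
    """Plays the role of prompt | model | output parser, without an LLM."""
    if not steps:
        # First pass: request a tool call on the last word of the question.
        word = user_input.rstrip("?").split()[-1]
        return Action(tool="get_word_length", tool_input={"word": word})
    # Later pass: a tool result is available, so produce the final answer.
    action, observation = steps[-1]
    word = action.tool_input["word"]
    return Finish(output=f'There are {observation} letters in the word "{word}".')


def run_agent(user_input: str, plan: Callable, tools: Dict[str, Callable], max_iterations: int = 5) -> dict:
    """Loop: plan, run the requested tool, record the observation, repeat."""
    steps: List[Tuple[Action, object]] = []
    for _ in range(max_iterations):
        decision = plan(user_input, steps)
        if isinstance(decision, Finish):
            return {"input": user_input, "output": decision.output}
        observation = tools[decision.tool](**decision.tool_input)
        steps.append((decision, observation))
    raise RuntimeError("agent did not finish within max_iterations")


result = run_agent(
    "How many letters in the word educa",
    fake_plan,
    {"get_word_length": get_word_length},
)
print(result["output"])
```

The list of `(action, observation)` pairs plays the role that `intermediate_steps` plays in the agent definition: it is what gets formatted into the `agent_scratchpad` on each pass.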
{
"cell_type": "code",
"execution_count": 14,
"id": "653b1617",
"metadata": {},
"outputs": [
@ -132,7 +169,11 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\u001b[36;1m\u001b[1;3mThe current population of Canada is 38,669,152 as of Monday, April 24, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `get_word_length` with `{'word': 'educa'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m5\u001b[0m\u001b[32;1m\u001b[1;3mThere are 5 letters in the word \"educa\".\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@ -140,22 +181,209 @@
{
"data": {
"text/plain": [
"'The current population of Canada is 38,669,152 as of Monday, April 24, 2023, based on Worldometer elaboration of the latest United Nations data.'"
"{'input': 'How many letters in the word educa',\n",
" 'output': 'There are 5 letters in the word \"educa\".'}"
]
},
"execution_count": 6,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.run(\"How many people live in canada as of 2023?\")"
"agent_executor.invoke({\"input\": \"How many letters in the word educa\"})"
]
},
{
"cell_type": "markdown",
"id": "e27d1f24",
"metadata": {},
"source": [
"If we compare this to the base LLM, we can see that the LLM alone struggles: it miscounts the letters."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "60f5dc19",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='There are 6 letters in the word \"educa\".')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.invoke(\"How many letters in the word educa\")"
]
},
{
"cell_type": "markdown",
"id": "6ade8e25",
"metadata": {},
"source": [
"## Adding memory\n",
"\n",
"This is great - we have an agent!\n",
"However, this agent is stateless - it doesn't remember anything about previous interactions.\n",
"This means you can't easily ask follow-up questions.\n",
"Let's fix that by adding memory.\n",
"\n",
"In order to do this, we need to do two things:\n",
"\n",
"1. Add a place for memory variables to go in the prompt\n",
"2. Keep track of the chat history\n",
"\n",
"First, let's add a place for memory in the prompt.\n",
"We do this by adding a placeholder for messages with the key `\"chat_history\"`.\n",
"Notice that we put this ABOVE the new user input (to follow the conversation flow)."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "169006d5",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import MessagesPlaceholder\n",
"\n",
"MEMORY_KEY = \"chat_history\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\n",
"            \"system\",\n",
"            \"You are a very powerful assistant, but bad at calculating the lengths of words.\",\n",
"        ),\n",
"        MessagesPlaceholder(variable_name=MEMORY_KEY),\n",
"        (\"user\", \"{input}\"),\n",
"        MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
"    ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1b945a39",
"metadata": {},
"source": [
"We can then set up a list to track the chat history."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8c03f36c",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import AIMessage, HumanMessage\n",
"\n",
"chat_history = []"
]
},
{
"cell_type": "markdown",
"id": "83541d01",
"metadata": {},
"source": [
"We can then put it all together!"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5429d97f",
"metadata": {},
"outputs": [],
"source": [
"agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" \"chat_history\": lambda x: x[\"chat_history\"],\n",
" }\n",
" | prompt\n",
" | llm_with_tools\n",
" | OpenAIFunctionsAgentOutputParser()\n",
")\n",
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "markdown",
"id": "3654b637",
"metadata": {},
"source": [
"When running the agent, we now need to record each input and output in the chat history ourselves."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "9d9da346",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `get_word_length` with `{'word': 'educa'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m5\u001b[0m\u001b[32;1m\u001b[1;3mThere are 5 letters in the word \"educa\".\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mNo, \"educa\" is not a real word in English.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'is that a real word?',\n",
" 'chat_history': [HumanMessage(content='how many letters in the word educa?'),\n",
" AIMessage(content='There are 5 letters in the word \"educa\".')],\n",
" 'output': 'No, \"educa\" is not a real word in English.'}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input1 = \"how many letters in the word educa?\"\n",
"result = agent_executor.invoke({\"input\": input1, \"chat_history\": chat_history})\n",
"chat_history.extend(\n",
" [\n",
" HumanMessage(content=input1),\n",
" AIMessage(content=result[\"output\"]),\n",
" ]\n",
")\n",
"agent_executor.invoke({\"input\": \"is that a real word?\", \"chat_history\": chat_history})"
]
},
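Because the history has to be updated by hand after every call, it can help to wrap that bookkeeping in a small helper. The sketch below is a hypothetical illustration in plain Python: `chat_turn` and `fake_invoke` are made-up names, messages are modelled as simple `(role, content)` tuples rather than `HumanMessage`/`AIMessage` objects, and `fake_invoke` stands in for `agent_executor.invoke` so that nothing calls an LLM.

```python
from typing import Callable, Dict, List, Tuple

Message = Tuple[str, str]  # (role, content); a simplification for this sketch


def chat_turn(invoke: Callable[[Dict], Dict], chat_history: List[Message], user_input: str) -> str:
    """Run one turn and append both sides of the exchange to the history."""
    # Pass a copy so the callee sees only the turns that came before this one.
    result = invoke({"input": user_input, "chat_history": list(chat_history)})
    chat_history.extend([("human", user_input), ("ai", result["output"])])
    return result["output"]


def fake_invoke(payload: Dict) -> Dict:
    """Stand-in for an executor: reports how many earlier turns it was given."""
    previous_turns = len(payload["chat_history"]) // 2
    return {"output": f"turn {previous_turns + 1}: {payload['input']}"}


history: List[Message] = []
first = chat_turn(fake_invoke, history, "how many letters in the word educa?")
second = chat_turn(fake_invoke, history, "is that a real word?")
print(first)
print(second)
print(len(history))
```

Unlike the cell above, which only records the first exchange, this helper records every turn, so a third question would see both earlier exchanges.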
{
"cell_type": "code",
"execution_count": null,
"id": "adefb4c2",
"id": "f21bcd99",
"metadata": {},
"outputs": [],
"source": []
@ -177,7 +405,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.10.1"
},
"vscode": {
"interpreter": {


@ -665,7 +665,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.1"
}
},
"nbformat": 4,