From fcb9b2ffe5e15a491da9715f8491cd1d745335e0 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 26 Nov 2022 06:06:44 -0800 Subject: [PATCH] Harrison/agent memory (#197) add doc for agent with memory --- docs/examples/memory/agent_with_memory.ipynb | 325 +++++++++++++++++++ langchain/chains/conversation/memory.py | 21 +- langchain/prompts/base.py | 16 +- langchain/prompts/prompt.py | 2 +- 4 files changed, 353 insertions(+), 11 deletions(-) create mode 100644 docs/examples/memory/agent_with_memory.ipynb diff --git a/docs/examples/memory/agent_with_memory.ipynb b/docs/examples/memory/agent_with_memory.ipynb new file mode 100644 index 00000000000..48c8533b26a --- /dev/null +++ b/docs/examples/memory/agent_with_memory.ipynb @@ -0,0 +1,325 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fa6802ac", + "metadata": {}, + "source": [ + "# Adding Memory to an Agent\n", + "\n", + "This notebook goes over adding memory to an Agent. Before going through this notebook, please walk through the following notebooks, as this will build on top of both of them:\n", + "\n", + "- [Adding memory to an LLM Chain](adding_memory.ipynb)\n", + "- [Custom Agents](../agents/custom_agent.ipynb)\n", + "\n", + "In order to add a memory to an agent we are going to perform the following steps:\n", + "\n", + "1. We are going to create an LLMChain with memory.\n", + "2. We are going to use that LLMChain to create a custom Agent.\n", + "\n", + "For the purposes of this exercise, we are going to create a simple custom Agent that has access to a search tool and utilizes the `ConversationBufferMemory` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8db95912", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.agents import ZeroShotAgent, Tool\n", + "from langchain.chains.conversation.memory import ConversationBufferMemory\n", + "from langchain import OpenAI, SerpAPIChain, LLMChain" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "97ad8467", + "metadata": {}, + "outputs": [], + "source": [ + "search = SerpAPIChain()\n", + "tools = [\n", + " Tool(\n", + " name = \"Search\",\n", + " func=search.run,\n", + " description=\"useful for when you need to answer questions about current events\"\n", + " )\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "4ad2e708", + "metadata": {}, + "source": [ + "Notice the usage of the `chat_history` variable in the PromptTemplate, which matches up with the dynamic key name in the ConversationBufferMemory." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e3439cd6", + "metadata": {}, + "outputs": [], + "source": [ + "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n", + "suffix = \"\"\"Begin!\"\n", + "\n", + "{chat_history}\n", + "Question: {input}\"\"\"\n", + "\n", + "prompt = ZeroShotAgent.create_prompt(\n", + " tools, \n", + " prefix=prefix, \n", + " suffix=suffix, \n", + " input_variables=[\"input\", \"chat_history\"]\n", + ")\n", + "memory = ConversationBufferMemory(dynamic_key=\"chat_history\")" + ] + }, + { + "cell_type": "markdown", + "id": "0021675b", + "metadata": {}, + "source": [ + "We can now construct the LLMChain, with the Memory object, and then create the agent." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c56a0e73", + "metadata": {}, + "outputs": [], + "source": [ + "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, memory=memory)\n", + "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ca4bc1fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "How many people live in canada?\n", + "Thought:\u001b[32;1m\u001b[1;3m I should look up how many people live in canada\n", + "Action: Search\n", + "Action Input: \"How many people live in canada?\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: The current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data.'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"How many people live in canada?\")" + ] + }, + { + "cell_type": "markdown", + "id": "45627664", + "metadata": {}, + "source": [ + "To test the memory of this agent, we can ask a followup question that relies on information in the previous exchange to be answered correctly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "eecc0462", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "what is their national anthem called?\n", + "Thought:\u001b[32;1m\u001b[1;3m\n", + "AI: I should look up the name of Canada's national anthem\n", + "Action: Search\n", + "Action Input: \"What is the name of Canada's national anthem?\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAfter 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m\n", + "AI: I now know the final answer\n", + "Final Answer: After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa Lavallée.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa Lavallée.\"" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.run(\"what is their national anthem called?\")" + ] + }, + { + "cell_type": "markdown", + "id": "cc3d0aa4", + "metadata": {}, + "source": [ + "We can see that the agent remembered that the previous question was about Canada, and properly asked Google Search what the name of Canada's national anthem was.\n", + "\n", + "For fun, let's compare this to an agent that does NOT have memory." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3359d043", + "metadata": {}, + "outputs": [], + "source": [ + "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. 
You have access to the following tools:\"\"\"\n", + "suffix = \"\"\"Begin!\"\n", + "\n", + "Question: {input}\"\"\"\n", + "\n", + "prompt = ZeroShotAgent.create_prompt(\n", + " tools, \n", + " prefix=prefix, \n", + " suffix=suffix, \n", + " input_variables=[\"input\"]\n", + ")\n", + "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n", + "agent_without_memory = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "970d23df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "How many people live in canada?\n", + "Thought:\u001b[32;1m\u001b[1;3m I should look up how many people live in canada\n", + "Action: Search\n", + "Action Input: \"How many people live in canada?\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data. 
· Canada 2020 ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: The current population of Canada is 38,533,678\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The current population of Canada is 38,533,678'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_without_memory.run(\"How many people live in canada?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d9ea82f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "what is their national anthem called?\n", + "Thought:\u001b[32;1m\u001b[1;3m I should probably look this up\n", + "Action: Search\n", + "Action Input: \"What is the national anthem of [country]\"\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mMost nation states have an anthem, defined as \"a song, as of praise, devotion, or patriotism\"; most anthems are either marches or hymns in style.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: The national anthem is called \"the national anthem.\"\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The national anthem is called \"the national anthem.\"'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_without_memory.run(\"what is their national anthem called?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b1f9223", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", 
+ "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py index 13250aae25a..b6bf9e2654a 100644 --- a/langchain/chains/conversation/memory.py +++ b/langchain/chains/conversation/memory.py @@ -10,6 +10,15 @@ from langchain.llms.base import LLM from langchain.prompts.base import BasePromptTemplate +def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: + # "stop" is a special key that can be passed as input but is not used to + # format the prompt. + prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"])) + if len(prompt_input_keys) != 1: + raise ValueError(f"One input key expected got {prompt_input_keys}") + return prompt_input_keys[0] + + class ConversationBufferMemory(Memory, BaseModel): """Buffer for storing conversation memory.""" @@ -30,12 +39,10 @@ class ConversationBufferMemory(Memory, BaseModel): def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" - prompt_input_keys = list(set(inputs).difference(self.memory_variables)) - if len(prompt_input_keys) != 1: - raise ValueError(f"One input key expected got {prompt_input_keys}") + prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables) if len(outputs) != 1: raise ValueError(f"One output key expected, got {outputs.keys()}") - human = "Human: " + inputs[prompt_input_keys[0]] + human = "Human: " + inputs[prompt_input_key] ai = "AI: " + outputs[list(outputs.keys())[0]] self.buffer += "\n" + "\n".join([human, ai]) @@ -74,12 +81,10 @@ class ConversationSummaryMemory(Memory, BaseModel): def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Save context from this conversation to buffer.""" - prompt_input_keys = 
list(set(inputs).difference(self.memory_variables)) - if len(prompt_input_keys) != 1: - raise ValueError(f"One input key expected got {prompt_input_keys}") + prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables) if len(outputs) != 1: raise ValueError(f"One output key expected, got {outputs.keys()}") - human = "Human: " + inputs[prompt_input_keys[0]] + human = "Human: " + inputs[prompt_input_key] ai = "AI: " + list(outputs.values())[0] new_lines = "\n".join([human, ai]) chain = LLMChain(llm=self.llm, prompt=self.prompt) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 44cdad9ff08..b44fec9a641 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,6 +1,8 @@ """BasePrompt schema definition.""" from abc import ABC, abstractmethod -from typing import Any, List +from typing import Any, Dict, List + +from pydantic import BaseModel, root_validator from langchain.formatting import formatter @@ -27,12 +29,22 @@ def check_valid_template( raise ValueError("Invalid prompt schema.") -class BasePromptTemplate(ABC): +class BasePromptTemplate(BaseModel, ABC): """Base prompt should expose the format method, returning a prompt.""" input_variables: List[str] """A list of the names of the variables the prompt template expects.""" + @root_validator() + def validate_variable_names(cls, values: Dict) -> Dict: + """Validate variable names do not include restricted names.""" + if "stop" in values["input_variables"]: + raise ValueError( + "Cannot have an input variable named 'stop', as it is used internally," + " please rename." + ) + return values + @abstractmethod def format(self, **kwargs: Any) -> str: """Format the prompt with the inputs. 
diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 6eccaaa3975..bb1d331fb3e 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -10,7 +10,7 @@ from langchain.prompts.base import ( ) -class PromptTemplate(BaseModel, BasePromptTemplate): +class PromptTemplate(BasePromptTemplate, BaseModel): """Schema to represent a prompt for an LLM. Example: