mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-18 18:53:10 +00:00)
parent 6eab5254e5
commit fcb9b2ffe5
docs/examples/memory/agent_with_memory.ipynb (325 lines, normal file)
@@ -0,0 +1,325 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa6802ac",
   "metadata": {},
   "source": [
    "# Adding Memory to an Agent\n",
    "\n",
    "This notebook goes over adding memory to an Agent. Before going through this notebook, please walk through the following notebooks, as this will build on top of both of them:\n",
    "\n",
    "- [Adding memory to an LLM Chain](adding_memory.ipynb)\n",
    "- [Custom Agents](../agents/custom_agent.ipynb)\n",
    "\n",
    "In order to add memory to an agent we are going to perform the following steps:\n",
    "\n",
    "1. We are going to create an LLMChain with memory.\n",
    "2. We are going to use that LLMChain to create a custom Agent.\n",
    "\n",
    "For the purposes of this exercise, we are going to create a simple custom Agent that has access to a search tool and utilizes the `ConversationBufferMemory` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8db95912",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import ZeroShotAgent, Tool\n",
    "from langchain.chains.conversation.memory import ConversationBufferMemory\n",
    "from langchain import OpenAI, SerpAPIChain, LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "97ad8467",
   "metadata": {},
   "outputs": [],
   "source": [
    "search = SerpAPIChain()\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        func=search.run,\n",
    "        description=\"useful for when you need to answer questions about current events\"\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ad2e708",
   "metadata": {},
   "source": [
    "Notice the usage of the `chat_history` variable in the PromptTemplate, which matches up with the dynamic key name in the ConversationBufferMemory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e3439cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
    "suffix = \"\"\"Begin!\n",
    "\n",
    "{chat_history}\n",
    "Question: {input}\"\"\"\n",
    "\n",
    "prompt = ZeroShotAgent.create_prompt(\n",
    "    tools,\n",
    "    prefix=prefix,\n",
    "    suffix=suffix,\n",
    "    input_variables=[\"input\", \"chat_history\"]\n",
    ")\n",
    "memory = ConversationBufferMemory(dynamic_key=\"chat_history\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0021675b",
   "metadata": {},
   "source": [
    "We can now construct the LLMChain, with the Memory object, and then create the agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c56a0e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, memory=memory)\n",
    "agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ca4bc1fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "How many people live in canada?\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should look up how many people live in canada\n",
      "Action: Search\n",
      "Action Input: \"How many people live in canada?\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: The current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"How many people live in canada?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45627664",
   "metadata": {},
   "source": [
    "To test the memory of this agent, we can ask a follow-up question that relies on information from the previous exchange to be answered correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eecc0462",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "what is their national anthem called?\n",
      "Thought:\u001b[32;1m\u001b[1;3m\n",
      "AI: I should look up the name of Canada's national anthem\n",
      "Action: Search\n",
      "Action Input: \"What is the name of Canada's national anthem?\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mAfter 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa ...\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m\n",
      "AI: I now know the final answer\n",
      "Final Answer: After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa Lavallée.\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"After 100 years of tradition, O Canada was proclaimed Canada's national anthem in 1980. The music for O Canada was composed in 1880 by Calixa Lavallée.\""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent.run(\"what is their national anthem called?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc3d0aa4",
   "metadata": {},
   "source": [
    "We can see that the agent remembered that the previous question was about Canada, and properly asked Google Search what the name of Canada's national anthem was.\n",
    "\n",
    "For fun, let's compare this to an agent that does NOT have memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3359d043",
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
    "suffix = \"\"\"Begin!\n",
    "\n",
    "Question: {input}\"\"\"\n",
    "\n",
    "prompt = ZeroShotAgent.create_prompt(\n",
    "    tools,\n",
    "    prefix=prefix,\n",
    "    suffix=suffix,\n",
    "    input_variables=[\"input\"]\n",
    ")\n",
    "llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)\n",
    "agent_without_memory = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "970d23df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "How many people live in canada?\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should look up how many people live in canada\n",
      "Action: Search\n",
      "Action Input: \"How many people live in canada?\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,533,678 as of Friday, November 25, 2022, based on Worldometer elaboration of the latest United Nations data. · Canada 2020 ...\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: The current population of Canada is 38,533,678\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The current population of Canada is 38,533,678'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_without_memory.run(\"How many people live in canada?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d9ea82f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new chain...\u001b[0m\n",
      "what is their national anthem called?\n",
      "Thought:\u001b[32;1m\u001b[1;3m I should probably look this up\n",
      "Action: Search\n",
      "Action Input: \"What is the national anthem of [country]\"\u001b[0m\n",
      "Observation: \u001b[36;1m\u001b[1;3mMost nation states have an anthem, defined as \"a song, as of praise, devotion, or patriotism\"; most anthems are either marches or hymns in style.\u001b[0m\n",
      "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
      "Final Answer: The national anthem is called \"the national anthem.\"\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'The national anthem is called \"the national anthem.\"'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_without_memory.run(\"what is their national anthem called?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1f9223",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
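The rest of the commit adjusts the supporting library code, shown as diffs below. Before those, here is a minimal, standalone sketch of the memory flow the notebook relies on. BufferMemorySketch is a hypothetical stand-in, not the LangChain class itself; it mirrors the save_context logic in the ConversationBufferMemory diff that follows, where each turn is appended to a buffer as "Human: ..." / "AI: ..." lines and the buffer is substituted for {chat_history} the next time the prompt is formatted.

    # Sketch only: BufferMemorySketch is a hypothetical stand-in for
    # ConversationBufferMemory, mirroring the save_context logic shown below.
    class BufferMemorySketch:
        def __init__(self, dynamic_key: str = "chat_history") -> None:
            self.dynamic_key = dynamic_key
            self.buffer = ""

        def save_context(self, human_input: str, ai_output: str) -> None:
            # Same shape as the diff: prefix each turn and append to the buffer.
            human = "Human: " + human_input
            ai = "AI: " + ai_output
            self.buffer += "\n" + "\n".join([human, ai])


    memory = BufferMemorySketch()
    memory.save_context(
        "How many people live in canada?",
        "The current population of Canada is 38,533,678 as of Friday, November 25, 2022, ...",
    )

    # On the next call, the chain fills {chat_history} with the buffer, which is
    # how the agent can resolve "their" in the follow-up question.
    suffix = """Begin!

    {chat_history}
    Question: {input}"""
    print(suffix.format(chat_history=memory.buffer,
                        input="what is their national anthem called?"))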
@@ -10,6 +10,15 @@ from langchain.llms.base import LLM
 from langchain.prompts.base import BasePromptTemplate
 
 
+def _get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
+    # "stop" is a special key that can be passed as input but is not used to
+    # format the prompt.
+    prompt_input_keys = list(set(inputs).difference(memory_variables + ["stop"]))
+    if len(prompt_input_keys) != 1:
+        raise ValueError(f"One input key expected got {prompt_input_keys}")
+    return prompt_input_keys[0]
+
+
 class ConversationBufferMemory(Memory, BaseModel):
     """Buffer for storing conversation memory."""
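For illustration, here is how the new helper behaves on inputs shaped like the notebook's (the dictionary values are hypothetical, not taken from the diff):

    inputs = {
        "input": "How many people live in canada?",   # the one true prompt input
        "chat_history": "Human: ...\nAI: ...",        # a memory variable
        "stop": ["\nObservation:"],                   # special key, ignored
    }
    # Memory variables and "stop" are filtered out, leaving exactly one key.
    assert _get_prompt_input_key(inputs, ["chat_history"]) == "input"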
@@ -30,12 +39,10 @@ class ConversationBufferMemory(Memory, BaseModel):
 
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Save context from this conversation to buffer."""
-        prompt_input_keys = list(set(inputs).difference(self.memory_variables))
-        if len(prompt_input_keys) != 1:
-            raise ValueError(f"One input key expected got {prompt_input_keys}")
+        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
         if len(outputs) != 1:
             raise ValueError(f"One output key expected, got {outputs.keys()}")
-        human = "Human: " + inputs[prompt_input_keys[0]]
+        human = "Human: " + inputs[prompt_input_key]
         ai = "AI: " + outputs[list(outputs.keys())[0]]
         self.buffer += "\n" + "\n".join([human, ai])
@@ -74,12 +81,10 @@ class ConversationSummaryMemory(Memory, BaseModel):
 
     def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
         """Save context from this conversation to buffer."""
-        prompt_input_keys = list(set(inputs).difference(self.memory_variables))
-        if len(prompt_input_keys) != 1:
-            raise ValueError(f"One input key expected got {prompt_input_keys}")
+        prompt_input_key = _get_prompt_input_key(inputs, self.memory_variables)
         if len(outputs) != 1:
             raise ValueError(f"One output key expected, got {outputs.keys()}")
-        human = "Human: " + inputs[prompt_input_keys[0]]
+        human = "Human: " + inputs[prompt_input_key]
         ai = "AI: " + list(outputs.values())[0]
         new_lines = "\n".join([human, ai])
         chain = LLMChain(llm=self.llm, prompt=self.prompt)
@@ -1,6 +1,8 @@
 """BasePrompt schema definition."""
 from abc import ABC, abstractmethod
-from typing import Any, List
+from typing import Any, Dict, List
 
+from pydantic import BaseModel, root_validator
+
 from langchain.formatting import formatter
@@ -27,12 +29,22 @@ def check_valid_template(
         raise ValueError("Invalid prompt schema.")
 
 
-class BasePromptTemplate(ABC):
+class BasePromptTemplate(BaseModel, ABC):
     """Base prompt should expose the format method, returning a prompt."""
 
     input_variables: List[str]
     """A list of the names of the variables the prompt template expects."""
 
+    @root_validator()
+    def validate_variable_names(cls, values: Dict) -> Dict:
+        """Validate that variable names do not include restricted names."""
+        if "stop" in values["input_variables"]:
+            raise ValueError(
+                "Cannot have an input variable named 'stop', as it is used internally,"
+                " please rename."
+            )
+        return values
+
     @abstractmethod
     def format(self, **kwargs: Any) -> str:
         """Format the prompt with the inputs.
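With this validator in place, constructing any BasePromptTemplate subclass with an input variable named "stop" should fail at creation time. A sketch of the effect (the import path and constructor arguments are assumed for this vintage of the library):

    from langchain.prompts.prompt import PromptTemplate  # path assumed

    # Expected to raise:
    # ValueError: Cannot have an input variable named 'stop', as it is used
    # internally, please rename.
    PromptTemplate(input_variables=["stop"], template="{stop}")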
@@ -10,7 +10,7 @@ from langchain.prompts.base import (
 )
 
 
-class PromptTemplate(BaseModel, BasePromptTemplate):
+class PromptTemplate(BasePromptTemplate, BaseModel):
     """Schema to represent a prompt for an LLM.
 
     Example:
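A plausible reading of this last base-order swap: now that BasePromptTemplate itself subclasses BaseModel (previous hunk), Python cannot linearize a base list that names BaseModel before its own subclass. A minimal sketch of the rule (class names here are illustrative, not from the commit):

    from pydantic import BaseModel

    class Base(BaseModel):
        pass

    class Works(Base, BaseModel):  # subclass listed before its parent: fine
        pass

    # class Broken(BaseModel, Base):  # TypeError: Cannot create a consistent MRO
    #     pass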