Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
37dc8a57b3 add memory to sql 2023-07-31 16:29:26 -07:00
2 changed files with 220 additions and 0 deletions

View File

@@ -0,0 +1,217 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "488d6ee8",
"metadata": {},
"source": [
"# Adding Memory to SQL Database Chain\n",
"\n",
"This notebook shows how to add memory to a SQLDatabaseChain."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6ef6918e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.utilities import SQLDatabase\n",
"from langchain_experimental.sql import SQLDatabaseChain"
]
},
{
"cell_type": "markdown",
"id": "600aedb5",
"metadata": {},
"source": [
"Set up the SQLDatabase and LLM"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b54c24c2",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../../../../notebooks/Chinook.db\")\n",
"llm = OpenAI(temperature=0, verbose=True)"
]
},
{
"cell_type": "markdown",
"id": "96a1543f",
"metadata": {},
"source": [
"Set up the memory"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fc103f91",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import ConversationBufferMemory\n",
"memory = ConversationBufferMemory()"
]
},
{
"cell_type": "markdown",
"id": "af31b91d",
"metadata": {},
"source": [
"Now we need add to a place for memory in the prompt template"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "debcff82",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"PROMPT_SUFFIX = \"\"\"Only use the following tables:\n",
"{table_info}\n",
"\n",
"Previous Conversation:\n",
"{history}\n",
"\n",
"Question: {input}\"\"\"\n",
"\n",
"_DEFAULT_TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
"\n",
"Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n",
"\n",
"Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"SQLQuery: SQL Query to run\n",
"SQLResult: Result of the SQLQuery\n",
"Answer: Final answer here\n",
"\n",
"\"\"\"\n",
"\n",
"PROMPT = PromptTemplate.from_template(\n",
" _DEFAULT_TEMPLATE + PROMPT_SUFFIX,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7f6115f4",
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, memory=memory)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b4753f69",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"name one employee\n",
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT FirstName, LastName FROM Employee LIMIT 1\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[('Andrew', 'Adams')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3mAndrew Adams\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Andrew Adams'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"name one employee\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "aa1100c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"how many letters in their name?\n",
"SQLQuery:\u001b[32;1m\u001b[1;3mSELECT LENGTH(FirstName) + LENGTH(LastName) AS 'NameLength' FROM Employee WHERE FirstName = 'Andrew' AND LastName = 'Adams'\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(11,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3mAndrew Adams has 11 letters in their name.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Andrew Adams has 11 letters in their name.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"how many letters in their name?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11525db8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -112,6 +112,9 @@ class SQLDatabaseChain(Chain):
"table_info": table_info,
"stop": ["\nSQLResult:"],
}
if self.memory is not None:
for k in self.memory.memory_variables:
llm_inputs[k] = inputs[k]
intermediate_steps: List = []
try:
intermediate_steps.append(llm_inputs) # input: sql generation