add custom prompt for LLMMathChain and SQLDatabase chain (#605)

This commit is contained in:
Harrison Chase
2023-01-13 06:28:51 -08:00
committed by GitHub
parent a87a2aacaa
commit 9f9afbb6a8
4 changed files with 192 additions and 8 deletions

View File

@@ -78,7 +78,7 @@
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees.\u001b[0m\n",
"\u001b[1m> Finished SQLDatabaseChain chain.\u001b[0m\n"
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
@@ -96,10 +96,93 @@
"db_chain.run(\"How many employees are there?\")"
]
},
{
"cell_type": "markdown",
"id": "aad2cba6",
"metadata": {},
"source": [
"## Customize Prompt\n",
"You can also customize the prompt that is used. Here is an example prompting it to understand that foobar is the same as the Employee table"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8ca7bafb",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"_DEFAULT_TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Use the following format:\n",
"\n",
"Question: \"Question here\"\n",
"SQLQuery: \"SQL Query to run\"\n",
"SQLResult: \"Result of the SQLQuery\"\n",
"Answer: \"Final answer here\"\n",
"\n",
"Only use the following tables:\n",
"\n",
"{table_info}\n",
"\n",
"If someone asks for the table foobar, they really mean the employee table.\n",
"\n",
"Question: {input}\"\"\"\n",
"PROMPT = PromptTemplate(\n",
" input_variables=[\"input\", \"table_info\", \"dialect\"], template=_DEFAULT_TEMPLATE\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ec47a2bf",
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ebb0674e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"How many employees are there in the foobar table? \n",
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' There are 9 employees in the foobar table.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"How many employees are there in the foobar table?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61d91b85",
"id": "e59a4740",
"metadata": {},
"outputs": [],
"source": []