mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 21:12:48 +00:00
add custom prompt for LLMMathChain and SQLDatabase chain (#605)
This commit is contained in:
@@ -78,7 +78,7 @@
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees.\u001b[0m\n",
|
||||
"\u001b[1m> Finished SQLDatabaseChain chain.\u001b[0m\n"
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -96,10 +96,93 @@
|
||||
"db_chain.run(\"How many employees are there?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aad2cba6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Customize Prompt\n",
|
||||
"You can also customize the prompt that is used. Here is an example prompting it to understand that foobar is the same as the Employee table"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "8ca7bafb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts.prompt import PromptTemplate\n",
|
||||
"\n",
|
||||
"_DEFAULT_TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
|
||||
"Use the following format:\n",
|
||||
"\n",
|
||||
"Question: \"Question here\"\n",
|
||||
"SQLQuery: \"SQL Query to run\"\n",
|
||||
"SQLResult: \"Result of the SQLQuery\"\n",
|
||||
"Answer: \"Final answer here\"\n",
|
||||
"\n",
|
||||
"Only use the following tables:\n",
|
||||
"\n",
|
||||
"{table_info}\n",
|
||||
"\n",
|
||||
"If someone asks for the table foobar, they really mean the employee table.\n",
|
||||
"\n",
|
||||
"Question: {input}\"\"\"\n",
|
||||
"PROMPT = PromptTemplate(\n",
|
||||
" input_variables=[\"input\", \"table_info\", \"dialect\"], template=_DEFAULT_TEMPLATE\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "ec47a2bf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ebb0674e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
||||
"How many employees are there in the foobar table? \n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' There are 9 employees in the foobar table.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db_chain.run(\"How many employees are there in the foobar table?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "61d91b85",
|
||||
"id": "e59a4740",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
Reference in New Issue
Block a user