{
"cells": [
{
"cell_type": "markdown",
"id": "4da7ae91-4973-4e97-a570-fa24024ec65d",
"metadata": {},
"source": [
"# How to do query validation as part of SQL question-answering\n",
"\n",
"Perhaps the most error-prone part of any SQL chain or agent is writing valid and safe SQL queries. In this guide we'll go over some strategies for validating our queries and handling invalid queries.\n",
"\n",
"We will cover: \n",
"\n",
"1. Appending a \"query validator\" step to the query generation;\n",
"2. Prompt engineering to reduce the incidence of errors.\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d40d5bc-3647-4b5d-808a-db470d40fe7a",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71f46270-e1c6-45b4-b36e-ea2e9f860eba",
"metadata": {},
"outputs": [],
"source": [
"# Uncomment the below to use LangSmith. Not required.\n",
"# import getpass\n",
"# import os\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "a0a2151b-cecf-4559-92a1-ca48824fed18",
"metadata": {},
"source": [
"The example below uses a SQLite connection to the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8cedc936-5268-4bfa-b838-bdcc1ee9573c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n",
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\n"
]
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"print(db.run(\"SELECT * FROM Artist LIMIT 10;\"))"
]
},
{
"cell_type": "markdown",
"id": "2d203315-fab7-4621-80da-41e9bf82d803",
"metadata": {},
"source": [
"## Query checker\n",
"\n",
"Perhaps the simplest strategy is to ask the model itself to check the original query for common mistakes. Suppose we have the following SQL query chain:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
"<ChatModelTabs customVarName=\"llm\" />\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d81ebf69-75ad-4c92-baa9-fd152b8e622a",
"metadata": {},
"outputs": [],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ec66bb76-b1ad-48ad-a7d4-b518e9421b86",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import create_sql_query_chain\n",
"\n",
"chain = create_sql_query_chain(llm, db)"
]
},
{
"cell_type": "markdown",
"id": "da01023d-cc05-43e3-a38d-ed9d56d3ad15",
"metadata": {},
"source": [
"And we want to validate its outputs. We can do so by extending the chain with a second prompt and model call:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "16686750-d8ee-4c60-8d67-b28281cb6164",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"system = \"\"\"Double check the user's {dialect} query for common mistakes, including:\n",
"- Using NOT IN with NULL values\n",
"- Using UNION when UNION ALL should have been used\n",
"- Using BETWEEN for exclusive ranges\n",
"- Data type mismatch in predicates\n",
"- Properly quoting identifiers\n",
"- Using the correct number of arguments for functions\n",
"- Casting to the correct data type\n",
"- Using the proper columns for joins\n",
"\n",
"If there are any of the above mistakes, rewrite the query.\n",
"If there are no mistakes, just reproduce the original query with no further commentary.\n",
"\n",
"Output the final SQL query only.\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [(\"system\", system), (\"human\", \"{query}\")]\n",
").partial(dialect=db.dialect)\n",
"validation_chain = prompt | llm | StrOutputParser()\n",
"\n",
"full_chain = {\"query\": chain} | validation_chain"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "28ef9c6e-21fa-4b62-8aa4-8cd398ce4c4d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT AVG(i.Total) AS AverageInvoice\n",
"FROM Invoice i\n",
"JOIN Customer c ON i.CustomerId = c.CustomerId\n",
"WHERE c.Country = 'USA'\n",
"AND c.Fax IS NULL\n",
"AND i.InvoiceDate >= '2003-01-01' \n",
"AND i.InvoiceDate < '2010-01-01'\n"
]
}
],
"source": [
"query = full_chain.invoke(\n",
"    {\n",
"        \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n",
"    }\n",
")\n",
"print(query)"
]
},
{
"cell_type": "markdown",
"id": "228a1b87-4e44-4d86-bed7-fd2d7a91fb23",
"metadata": {},
"source": [
"Note how we can see both steps of the chain in the [LangSmith trace](https://smith.langchain.com/public/8a743295-a57c-4e4c-8625-bc7e36af9d74/r)."
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "d01d78b5-89a0-4c12-b743-707ebe64ba86",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(6.632999999999998,)]'"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "6e133526-26bd-49da-9cfa-7adc0e59fd72",
"metadata": {},
"source": [
"The obvious downside of this approach is that we need two model calls instead of one to generate our query. To get around this, we can try to perform the query generation and the query check in a single model invocation:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "7af0030a-549e-4e69-9298-3d0a038c2fdd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m System Message \u001b[0m================================\n",
"\n",
"You are a \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m expert. Given an input question, create a syntactically correct \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query to run.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most \u001b[33;1m\u001b[1;3m{top_k}\u001b[0m results using the LIMIT clause as per \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Only use the following tables:\n",
"\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
"\n",
"Write an initial draft of the query. Then double check the \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query for common mistakes, including:\n",
"- Using NOT IN with NULL values\n",
"- Using UNION when UNION ALL should have been used\n",
"- Using BETWEEN for exclusive ranges\n",
"- Data type mismatch in predicates\n",
"- Properly quoting identifiers\n",
"- Using the correct number of arguments for functions\n",
"- Casting to the correct data type\n",
"- Using the proper columns for joins\n",
"\n",
"Use format:\n",
"\n",
"First draft: <<FIRST_DRAFT_QUERY>>\n",
"Final answer: <<FINAL_ANSWER_QUERY>>\n",
"\n",
"\n",
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
]
}
],
"source": [
"system = \"\"\"You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Only use the following tables:\n",
"{table_info}\n",
"\n",
"Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:\n",
"- Using NOT IN with NULL values\n",
"- Using UNION when UNION ALL should have been used\n",
"- Using BETWEEN for exclusive ranges\n",
"- Data type mismatch in predicates\n",
"- Properly quoting identifiers\n",
"- Using the correct number of arguments for functions\n",
"- Casting to the correct data type\n",
"- Using the proper columns for joins\n",
"\n",
"Use format:\n",
"\n",
"First draft: <<FIRST_DRAFT_QUERY>>\n",
"Final answer: <<FINAL_ANSWER_QUERY>>\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [(\"system\", system), (\"human\", \"{input}\")]\n",
").partial(dialect=db.dialect)\n",
"\n",
"\n",
"def parse_final_answer(output: str) -> str:\n",
"    return output.split(\"Final answer: \")[1]\n",
"\n",
"\n",
"chain = create_sql_query_chain(llm, db, prompt=prompt) | parse_final_answer\n",
"prompt.pretty_print()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "806e27a2-e511-45ea-a4ed-8ce8fa6e1d58",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"SELECT AVG(i.\"Total\") AS \"AverageInvoice\"\n",
"FROM \"Invoice\" i\n",
"JOIN \"Customer\" c ON i.\"CustomerId\" = c.\"CustomerId\"\n",
"WHERE c.\"Country\" = 'USA'\n",
"AND c.\"Fax\" IS NULL\n",
"AND i.\"InvoiceDate\" BETWEEN '2003-01-01' AND '2010-01-01';\n"
]
}
],
"source": [
"query = chain.invoke(\n",
"    {\n",
"        \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n",
"    }\n",
")\n",
"print(query)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "70fff2fa-1f86-4f83-9fd2-e87a5234d329",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(6.632999999999998,)]'"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
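{
"cell_type": "markdown",
"id": "3b9f2c1e-7a4d-4e58-9c06-5d1e8f7a2b40",
"metadata": {},
"source": [
"The `parse_final_answer` helper above assumes the model always emits the `Final answer: ` marker; if the marker is missing, indexing the split result raises an `IndexError`. A more defensive variant might look like the sketch below. The name `parse_final_answer_safe` and its fallback (returning the whole output, stripped) are assumptions made for illustration, not part of LangChain. It could be piped after `create_sql_query_chain(llm, db, prompt=prompt)` in place of `parse_final_answer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4d8e6a2-1f3b-4b7a-8e95-6a0b2c9d7f13",
"metadata": {},
"outputs": [],
"source": [
"def parse_final_answer_safe(output: str, marker: str = \"Final answer:\") -> str:\n",
"    # Hypothetical helper: split on the LAST occurrence of the marker.\n",
"    # If the marker is absent, rpartition returns ('', '', output),\n",
"    # so `tail` is the whole output and we fall back to returning it.\n",
"    _, _, tail = output.rpartition(marker)\n",
"    # Strip the whitespace and newlines that often surround the query.\n",
"    return tail.strip()\n",
"\n",
"\n",
"# Usage with a made-up model response:\n",
"print(parse_final_answer_safe(\"First draft: SELECT 1\\nFinal answer: SELECT 2;\"))"
]
},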
{
"cell_type": "markdown",
"id": "fc8af115-7c23-421a-8fd7-29bf1b6687a4",
"metadata": {},
"source": [
"## Human-in-the-loop\n",
"\n",
"In some cases our data is sensitive enough that we never want to execute a SQL query without a human approving it first. Head to the [Tool use: Human-in-the-loop](/docs/how_to/tools_human) page to learn how to add a human-in-the-loop step to any tool, chain or agent.\n",
"\n",
"## Error handling\n",
"\n",
"At some point, the model will make a mistake and craft an invalid SQL query, an issue will arise with our database, or the model API will go down. We'll want to add error handling to our chains and agents so that they fail gracefully in these situations, and perhaps even recover automatically. To learn about error handling with tools, head to the [Tool use: Error handling](/docs/how_to/tools_error) page."
]
}
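,
{
"cell_type": "markdown",
"id": "9e2a7d40-6c1b-4f83-b5a9-0d3c4e8f1a27",
"metadata": {},
"source": [
"As a minimal illustration of failing gracefully, the sketch below wraps query execution in a `try`/`except` and returns the error message instead of raising. It uses Python's built-in `sqlite3` module against a small in-memory table rather than the `SQLDatabase` wrapper, and the helper name `run_query_safely` is hypothetical. A real chain might pass the error text back to the model to request a corrected query:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c8b36-2d7e-4a94-a0e3-7b6d9c2e4f58",
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"\n",
"def run_query_safely(conn: sqlite3.Connection, query: str) -> dict:\n",
"    # Hypothetical helper: report database errors as data instead of raising.\n",
"    try:\n",
"        rows = conn.execute(query).fetchall()\n",
"        return {\"ok\": True, \"rows\": rows, \"error\": None}\n",
"    except sqlite3.Error as exc:\n",
"        # sqlite3.Error is the base class for the module's database exceptions.\n",
"        return {\"ok\": False, \"rows\": [], \"error\": str(exc)}\n",
"\n",
"\n",
"# Stand-in table for illustration only (assumption: not the real Chinook schema).\n",
"conn = sqlite3.connect(\":memory:\")\n",
"conn.execute(\"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)\")\n",
"conn.execute(\"INSERT INTO Artist (Name) VALUES ('AC/DC')\")\n",
"\n",
"print(run_query_safely(conn, \"SELECT Name FROM Artist\"))  # valid query\n",
"print(run_query_safely(conn, \"SELECT Nme FROM Artist\"))  # misspelled column"
]
}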
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}