core[patch], community[patch], langchain[patch], docs: Update SQL chains/agents/docs (#16168)

Revamp SQL use cases docs. In the process update SQL chains and agents.
Bagatur 2024-01-22 08:19:08 -08:00 committed by GitHub
parent 05162928c0
commit 1dc6c1ce06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 3618 additions and 1361 deletions


@@ -78,7 +78,7 @@ Let models choose which tools to use given high-level directives
Walkthroughs and techniques for common end-to-end use cases, like:
- [Document question answering](/docs/use_cases/question_answering/)
- [Chatbots](/docs/use_cases/chatbots/)
- [Analyzing structured data](/docs/use_cases/qa_structured/sql/)
- [Analyzing structured data](/docs/use_cases/sql/)
- and much more...
### [Integrations](/docs/integrations/providers/)


@@ -597,6 +597,6 @@ To continue on your journey, we recommend you read the following (in order):
- [Model IO](/docs/modules/model_io) covers more details of prompts, LLMs, and output parsers.
- [Retrieval](/docs/modules/data_connection) covers more details of everything related to retrieval
- [Agents](/docs/modules/agents) covers details of everything related to agents
- Explore common [end-to-end use cases](/docs/use_cases/qa_structured/sql) and [template applications](/docs/templates)
- Explore common [end-to-end use cases](/docs/use_cases/) and [template applications](/docs/templates)
- [Read up on LangSmith](/docs/langsmith/), the platform for debugging, testing, monitoring and more
- Learn more about serving your applications with [LangServe](/docs/langserve)


@@ -33,7 +33,7 @@ db = SQLDatabase.from_uri(conn_str)
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db, verbose=True)
```
From here, see the [SQL Chain](/docs/use_cases/tabular/sqlite) documentation on how to use.
From here, see the [SQL Chain](/docs/use_cases/sql/) documentation on how to use.
## LLMCache


@@ -1,2 +0,0 @@
label: 'Q&A over structured data'
position: 0.1

File diff suppressed because it is too large.


@@ -38,7 +38,7 @@
"\n",
"**Note**: Here we focus on Q&A for unstructured data. Two RAG use cases which we cover elsewhere are:\n",
"\n",
"- [Q&A over structured data](/docs/use_cases/qa_structured/sql) (e.g., SQL)\n",
"- [Q&A over SQL data](/docs/use_cases/sql/)\n",
"- [Q&A over code](/docs/use_cases/code_understanding) (e.g., Python)"
]
},
@@ -103,7 +103,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
"version": "3.9.1"
}
},
"nbformat": 4,


@@ -0,0 +1,815 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 1\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Agents\n",
"\n",
"LangChain has a SQL Agent which provides a more flexible way of interacting with SQL Databases than a chain. The main advantages of using the SQL Agent are:\n",
"\n",
"- It can answer questions based on the database's schema as well as on the database's content (like describing a specific table).\n",
"- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n",
"- It can query the database as many times as needed to answer the user question.\n",
"\n",
"To initialize the agent we'll use the [create_sql_agent](https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html) constructor. This agent uses the `SQLDatabaseToolkit` which contains tools to: \n",
"\n",
"* Create and execute queries\n",
"* Check query syntax\n",
"* Retrieve table descriptions\n",
"* ... and more\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
"\n",
"# Uncomment the below to use LangSmith. Not required.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The example below uses a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent\n",
"\n",
"We'll use an OpenAI chat model and an `\"openai-tools\"` agent, which will use OpenAI's function-calling API to drive the agent's tool selection and invocations."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.agent_toolkits import create_sql_agent\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_schema` with `Invoice,Customer`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Invoice table:\n",
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC`\n",
"responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n",
"\n",
"Here is the SQL query:\n",
"\n",
"```sql\n",
"SELECT c.Country, SUM(i.Total) AS TotalSales\n",
"FROM Invoice i\n",
"JOIN Customer c ON i.CustomerId = c.CustomerId\n",
"GROUP BY c.Country\n",
"ORDER BY TotalSales DESC\n",
"```\n",
"\n",
"Now, I will execute this query to get the results.\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.620000000000005), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.620000000000005), ('Poland', 37.620000000000005), ('Italy', 37.620000000000005), ('Denmark', 37.620000000000005), ('Australia', 37.620000000000005), ('Argentina', 37.620000000000005), ('Spain', 37.62), ('Belgium', 37.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
"\n",
"1. USA: $523.06\n",
"2. Canada: $303.96\n",
"3. France: $195.10\n",
"4. Brazil: $190.10\n",
"5. Germany: $156.48\n",
"6. United Kingdom: $112.86\n",
"7. Czech Republic: $90.24\n",
"8. Portugal: $77.24\n",
"9. India: $75.26\n",
"10. Chile: $46.62\n",
"\n",
"The country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
" 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nThe country whose customers spent the most is the USA, with a total sales of $523.06.'}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" \"List the total sales per country. Which country's customers spent the most?\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_schema` with `PlaylistTrack`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from PlaylistTrack table:\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n",
"\n",
"Here is the schema of the `PlaylistTrack` table:\n",
"\n",
"```\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"```\n",
"\n",
"The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n",
"\n",
"Here are three sample rows from the `PlaylistTrack` table:\n",
"\n",
"```\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\n",
"```\n",
"\n",
"Please let me know if there is anything else I can help with.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'Describe the playlisttrack table',\n",
" 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\"Describe the playlisttrack table\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using a dynamic few-shot prompt\n",
"\n",
"To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case we'll create a few-shot prompt with an example selector that dynamically builds the few-shot prompt based on the user input.\n",
"\n",
"First we need some user input <> SQL query examples:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"examples = [\n",
" {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n",
" {\n",
" \"input\": \"Find all albums for the artist 'AC/DC'.\",\n",
" \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n",
" },\n",
" {\n",
" \"input\": \"List all tracks in the 'Rock' genre.\",\n",
" \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n",
" },\n",
" {\n",
" \"input\": \"Find the total duration of all tracks.\",\n",
" \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n",
" },\n",
" {\n",
" \"input\": \"List all customers from Canada.\",\n",
" \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n",
" },\n",
" {\n",
" \"input\": \"How many tracks are there in the album with ID 5?\",\n",
" \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n",
" },\n",
" {\n",
" \"input\": \"Find the total number of invoices.\",\n",
" \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n",
" },\n",
" {\n",
" \"input\": \"List all tracks that are longer than 5 minutes.\",\n",
" \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n",
" },\n",
" {\n",
" \"input\": \"Who are the top 5 customers by total purchase?\",\n",
" \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n",
" },\n",
" {\n",
" \"input\": \"Which albums are from the year 2000?\",\n",
" \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n",
" },\n",
" {\n",
" \"input\": \"How many employees are there\",\n",
" \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create an example selector. This will take the actual user input and select some number of examples to add to our few-shot prompt. We'll use a `SemanticSimilarityExampleSelector`, which will perform a semantic search using the embeddings and vector store we configure to find the examples most similar to our input:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import FAISS\n",
"from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"example_selector = SemanticSimilarityExampleSelector.from_examples(\n",
" examples,\n",
" OpenAIEmbeddings(),\n",
" FAISS,\n",
" k=5,\n",
" input_keys=[\"input\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create our `FewShotPromptTemplate`, which takes our example selector, an example prompt for formatting each example, and a string prefix and suffix to put before and after our formatted examples:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import (\n",
" ChatPromptTemplate,\n",
" FewShotPromptTemplate,\n",
" MessagesPlaceholder,\n",
" PromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
")\n",
"\n",
"system_prefix = \"\"\"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
"\n",
"Here are some examples of user inputs and their corresponding SQL queries:\"\"\"\n",
"\n",
"few_shot_prompt = FewShotPromptTemplate(\n",
" example_selector=example_selector,\n",
" example_prompt=PromptTemplate.from_template(\n",
" \"User input: {input}\\nSQL query: {query}\"\n",
" ),\n",
" input_variables=[\"input\", \"dialect\", \"top_k\"],\n",
" prefix=system_prefix,\n",
" suffix=\"\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since our underlying agent is an [OpenAI tools agent](/docs/modules/agents/agent_types/openai_tools), which uses OpenAI function calling, our full prompt should be a chat prompt with a human message template and an agent_scratchpad `MessagesPlaceholder`. The few-shot prompt will be used for our system message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"full_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" SystemMessagePromptTemplate(prompt=few_shot_prompt),\n",
" (\"human\", \"{input}\"),\n",
" MessagesPlaceholder(\"agent_scratchpad\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"System: You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
"\n",
"Here are some examples of user inputs and their corresponding SQL queries:\n",
"\n",
"User input: List all artists.\n",
"SQL query: SELECT * FROM Artist;\n",
"\n",
"User input: How many employees are there\n",
"SQL query: SELECT COUNT(*) FROM \"Employee\"\n",
"\n",
"User input: How many tracks are there in the album with ID 5?\n",
"SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n",
"\n",
"User input: List all tracks in the 'Rock' genre.\n",
"SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
"\n",
"User input: Which albums are from the year 2000?\n",
"SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n",
"Human: How many arists are there\n"
]
}
],
"source": [
"# Example formatted prompt\n",
"prompt_val = full_prompt.invoke(\n",
" {\n",
" \"input\": \"How many arists are there\",\n",
" \"top_k\": 5,\n",
" \"dialect\": \"SQLite\",\n",
" \"agent_scratchpad\": [],\n",
" }\n",
")\n",
"print(prompt_val.to_string())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we can create our agent with our custom prompt:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"agent = create_sql_agent(\n",
" llm=llm,\n",
" db=db,\n",
" prompt=full_prompt,\n",
" verbose=True,\n",
" agent_type=\"openai-tools\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try it out:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 artists in the database.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'How many artists are there?',\n",
" 'output': 'There are 275 artists in the database.'}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.invoke({\"input\": \"How many artists are there?\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dealing with high-cardinality columns\n",
"\n",
"To filter on columns that contain proper nouns such as addresses, song names, or artists, we first need to double-check the spelling so that we filter the data correctly.\n",
"\n",
"We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n",
"\n",
"First we need the unique values for each entity we want to filter on, so we define a function that parses the result into a list of elements:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['For Those About To Rock We Salute You',\n",
" 'Balls to the Wall',\n",
" 'Restless and Wild',\n",
" 'Let There Be Rock',\n",
" 'Big Ones']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ast\n",
"import re\n",
"\n",
"\n",
"def query_as_list(db, query):\n",
" res = db.run(query)\n",
" res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
" res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
" return res\n",
"\n",
"\n",
"artists = query_as_list(db, \"SELECT Name FROM Artist\")\n",
"albums = query_as_list(db, \"SELECT Title FROM Album\")\n",
"albums[:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can proceed with creating the custom **retriever tool** and the final agent:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents.agent_toolkits import create_retriever_tool\n",
"\n",
"vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())\n",
"retriever = vector_db.as_retriever(search_kwargs={\"k\": 5})\n",
"description = \"\"\"Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \\\n",
"valid proper nouns. Use the noun most similar to the search.\"\"\"\n",
"retriever_tool = create_retriever_tool(\n",
" retriever,\n",
" name=\"search_proper_nouns\",\n",
" description=description,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"system = \"\"\"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
"You have access to tools for interacting with the database.\n",
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
"\n",
"If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool!\n",
"\n",
"You have access to the following tables: {table_names}\n",
"\n",
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
"    [(\"system\", system), (\"human\", \"{input}\"), MessagesPlaceholder(\"agent_scratchpad\")]\n",
")\n",
"agent = create_sql_agent(\n",
" llm=llm,\n",
" db=db,\n",
" extra_tools=[retriever_tool],\n",
" prompt=prompt,\n",
" agent_type=\"openai-tools\",\n",
" verbose=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `search_proper_nouns` with `{'query': 'alice in chains'}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mAlice In Chains\n",
"\n",
"Metallica\n",
"\n",
"Pearl Jam\n",
"\n",
"Pearl Jam\n",
"\n",
"Smashing Pumpkins\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `{'query': \"SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\"}`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'How many albums does alice in chains have?',\n",
" 'output': 'Alice In Chains has 1 album.'}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.invoke({\"input\": \"How many albums does alice in chains have?\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the agent used the `search_proper_nouns` tool to check how to correctly query the database for this specific artist."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next steps\n",
"\n",
"Under the hood, `create_sql_agent` is just passing in SQL tools to more generic agent constructors. To learn more about the built-in generic agent types as well as how to build custom agents, head to the [Agents Modules](/docs/modules/agents/).\n",
"\n",
"The built-in `AgentExecutor` runs a simple Agent action -> Tool call -> Agent action... loop. To build more complex agent runtimes, head to the [LangGraph section](/docs/langgraph)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
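
The simple Agent action -> Tool call -> Agent action loop that `AgentExecutor` runs can be sketched in a few lines. The snippet below is a toy model of that loop, not LangChain internals: `fake_llm` scripts the decisions a real chat model would make, and the tool returns a canned result instead of actually querying Chinook.db, so no API keys or database are needed.

```python
# A toy sketch of the Agent action -> Tool call -> Agent action loop.
# Both the "LLM" and the tool are illustrative stand-ins.

def fake_llm(scratchpad):
    # With an empty scratchpad, "decide" to call a tool; afterwards, finish.
    if not scratchpad:
        return ("tool", "sql_db_query", "SELECT COUNT(*) FROM Artist")
    tool_name, result = scratchpad[-1]
    return ("finish", f"There are {result[0][0]} artists.")


def sql_db_query(query):
    # Canned stand-in for executing the query against the database.
    return [(275,)]


tools = {"sql_db_query": sql_db_query}


def run_agent(llm, tools):
    scratchpad = []  # records (tool_name, observation) pairs
    while True:
        step = llm(scratchpad)
        if step[0] == "finish":
            return step[1]
        _, name, arg = step
        scratchpad.append((name, tools[name](arg)))


print(run_agent(fake_llm, tools))  # There are 275 artists.
```

A real executor adds prompt formatting, argument parsing, and error handling around the same skeleton.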


@@ -0,0 +1,68 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 0.5\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SQL\n",
"\n",
"One of the most common types of databases that we can build Q&A systems for is the SQL database. LangChain comes with a number of built-in chains and agents that are compatible with any SQL dialect supported by SQLAlchemy (e.g., MySQL, PostgreSQL, Oracle SQL, Databricks, SQLite). They enable use cases such as:\n",
"\n",
"* Generating queries that will be run based on natural language questions,\n",
"* Creating chatbots that can answer questions based on database data,\n",
"* Building custom dashboards based on insights a user wants to analyze,\n",
"\n",
"and much more.\n",
"\n",
"## ⚠️ Security note ⚠️\n",
"\n",
    "Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
"\n",
"![sql_usecase.png](../../../static/img/sql_usecase.png)\n",
"\n",
"## Quickstart\n",
"\n",
"Head to the **[Quickstart](/docs/use_cases/sql/quickstart)** page to get started.\n",
"\n",
"## Advanced\n",
"\n",
"Once you've familiarized yourself with the basics, you can head to the advanced guides:\n",
"\n",
"* [Agents](/docs/use_cases/sql/agents): Building agents that can interact with SQL DBs.\n",
"* [Prompting strategies](/docs/use_cases/sql/prompting): Strategies for improving SQL query generation.\n",
"* [Query validation](/docs/use_cases/sql/query_checking): How to validate SQL queries.\n",
"* [Large databases](/docs/use_cases/sql/large_db): How to interact with DBs with many tables and high-cardinality columns."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}


@@ -0,0 +1,627 @@
{
"cells": [
{
"cell_type": "raw",
"id": "b2788654-3f62-4e2a-ab00-471922cc54df",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 4\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "6751831d-9b08-434f-829b-d0052a3b119f",
"metadata": {},
"source": [
"# Large databases\n",
"\n",
"In order to write valid queries against a database, we need to feed the model the table names, table schemas, and feature values for it to query over. When there are many tables, columns, and/or high-cardinality columns, it becomes impossible for us to dump the full information about our database in every prompt. Instead, we must find ways to dynamically insert into the prompt only the most relevant information. Let's take a look at some techniques for doing this.\n",
"\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9675e433-e608-469e-b04e-2847479a8310",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "markdown",
"id": "4f56ff5d-b2e4-49e3-a0b4-fb99466cfedc",
"metadata": {},
"source": [
"We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "06d8dd03-2d7b-4fef-b145-43c074eacb8b",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [
"import getpass\n",
"import os\n",
"\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
    "\n",
    "# Uncomment the below to use LangSmith. Not required.\n",
    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "590ee096-db88-42af-90d4-99b8149df753",
"metadata": {},
"source": [
    "The below example will use a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) class:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cebd3915-f58f-4e73-8459-265630ae8cd4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"id": "2e572e1f-99b5-46a2-9023-76d1e6256c0a",
"metadata": {},
"source": [
"## Many tables\n",
"\n",
    "One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. When we have many tables, we can't fit all of their schemas in a single prompt. In such cases, we can first extract the names of the tables related to the user input, and then include only their schemas.\n",
"\n",
"One easy and reliable way to do this is using OpenAI function-calling and Pydantic models. LangChain comes with a built-in [create_extraction_chain_pydantic](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_tools.extraction.create_extraction_chain_pydantic.html) chain that lets us do just this:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "d8236886-c54f-4bdb-ad74-2514888628fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Table(name='Genre'), Table(name='Artist'), Table(name='Track')]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains.openai_tools import create_extraction_chain_pydantic\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo-1106\", temperature=0)\n",
"\n",
"\n",
"class Table(BaseModel):\n",
" \"\"\"Table in SQL database.\"\"\"\n",
"\n",
" name: str = Field(description=\"Name of table in SQL database.\")\n",
"\n",
"\n",
"table_names = \"\\n\".join(db.get_usable_table_names())\n",
"system = f\"\"\"Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \\\n",
"The tables are:\n",
"\n",
"{table_names}\n",
"\n",
"Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.\"\"\"\n",
"table_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n",
"table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
{
"cell_type": "markdown",
"id": "1641dbba-d359-4cb2-ac52-82dfae99f392",
"metadata": {},
"source": [
    "This works pretty well! Except, as we'll see below, we actually need a few other tables as well. This would be pretty difficult for the model to know based on the user question alone. In this case, we can simplify the model's job by grouping the tables together. We'll just ask the model to choose between the categories \"Music\" and \"Business\", and then take care of selecting all the relevant tables from there:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "0ccb0bf5-c580-428f-9cde-a58772ae784e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Table(name='Music')]"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"system = f\"\"\"Return the names of the SQL tables that are relevant to the user question. \\\n",
"The tables are:\n",
"\n",
"Music\n",
"Business\"\"\"\n",
"category_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)\n",
"category_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "ae4899fc-6f8a-4b10-983c-9e3fef4a7bb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import List\n",
"\n",
"\n",
"def get_tables(categories: List[Table]) -> List[str]:\n",
" tables = []\n",
" for category in categories:\n",
" if category.name == \"Music\":\n",
" tables.extend(\n",
" [\n",
" \"Album\",\n",
" \"Artist\",\n",
" \"Genre\",\n",
" \"MediaType\",\n",
" \"Playlist\",\n",
" \"PlaylistTrack\",\n",
" \"Track\",\n",
" ]\n",
" )\n",
" elif category.name == \"Business\":\n",
" tables.extend([\"Customer\", \"Employee\", \"Invoice\", \"InvoiceLine\"])\n",
" return tables\n",
"\n",
"\n",
"table_chain = category_chain | get_tables # noqa\n",
"table_chain.invoke({\"input\": \"What are all the genres of Alanis Morisette songs\"})"
]
},
{
"cell_type": "markdown",
"id": "04d52d01-1ccf-4753-b34a-0dcbc4921f78",
"metadata": {},
"source": [
    "Now that we've got a chain that can output the relevant tables for any query, we can combine this with our [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html), which can accept a list of `table_names_to_use` to determine which table schemas are included in the prompt:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "79f2a5a2-eb99-47e3-9c2b-e5751a800174",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"\n",
"from langchain.chains import create_sql_query_chain\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"query_chain = create_sql_query_chain(llm, db)\n",
"# Convert \"question\" key to the \"input\" key expected by current table_chain.\n",
"table_chain = {\"input\": itemgetter(\"question\")} | table_chain\n",
"# Set table_names_to_use using table_chain.\n",
"full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "424a7564-f63c-4584-b734-88021926486d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT \"Genre\".\"Name\"\n",
"FROM \"Genre\"\n",
"JOIN \"Track\" ON \"Genre\".\"GenreId\" = \"Track\".\"GenreId\"\n",
"JOIN \"Album\" ON \"Track\".\"AlbumId\" = \"Album\".\"AlbumId\"\n",
"JOIN \"Artist\" ON \"Album\".\"ArtistId\" = \"Artist\".\"ArtistId\"\n",
"WHERE \"Artist\".\"Name\" = 'Alanis Morissette'\n"
]
}
],
"source": [
"query = full_chain.invoke(\n",
" {\"question\": \"What are all the genres of Alanis Morisette songs\"}\n",
")\n",
"print(query)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "3fb715cf-69d1-46a6-a1a7-9715ee550a0c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"[('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',)]\""
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "bb3d12b0-81a6-4250-8bc4-d58fe762c4cc",
"metadata": {},
"source": [
    "We might rephrase our question slightly to remove redundancy in the answer:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "010b5c3c-d55b-461a-8de5-8f1a8b2c56ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT DISTINCT g.Name\n",
"FROM Genre g\n",
"JOIN Track t ON g.GenreId = t.GenreId\n",
"JOIN Album a ON t.AlbumId = a.AlbumId\n",
"JOIN Artist ar ON a.ArtistId = ar.ArtistId\n",
"WHERE ar.Name = 'Alanis Morissette'\n"
]
}
],
"source": [
"query = full_chain.invoke(\n",
" {\"question\": \"What is the set of all unique genres of Alanis Morisette songs\"}\n",
")\n",
"print(query)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "d21c0563-1f55-4577-8222-b0e9802f1c4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"[('Rock',)]\""
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "7a717020-84c2-40f3-ba84-6624138d8e0c",
"metadata": {},
"source": [
    "We can see the [LangSmith trace](https://smith.langchain.com/public/20b8ef90-1dac-4754-90f0-6bc11203c50a/r) for this run.\n",
"\n",
"We've seen how to dynamically include a subset of table schemas in a prompt within a chain. Another possible approach to this problem is to let an Agent decide for itself when to look up tables by giving it a Tool to do so. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide."
]
},
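  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of that tool-based approach (assuming the `db` object from above; `tool` comes from `langchain_core.tools`, and the tool name and signature here are illustrative), a schema-lookup tool might look like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "@tool\n",
    "def get_table_schemas(table_names: str) -> str:\n",
    "    \"\"\"Get the schemas for a comma-separated list of table names.\"\"\"\n",
    "    # SQLDatabase.get_table_info accepts an optional list of table names.\n",
    "    tables = [name.strip() for name in table_names.split(\",\")]\n",
    "    return db.get_table_info(tables)"
   ]
  },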
{
"cell_type": "markdown",
"id": "cb9e54fd-64ca-4ed5-847c-afc635aae4f5",
"metadata": {},
"source": [
"## High-cardinality columns\n",
"\n",
    "In order to filter columns that contain proper nouns such as addresses, song names, or artists, we first need to double-check the spelling so that we filter the data correctly.\n",
    "\n",
    "One naive strategy is to create a vector store with all the distinct proper nouns that exist in the database. We can then query that vector store with each user input and inject the most relevant proper nouns into the prompt.\n",
"\n",
"First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dee1b9e1-36b0-4cc1-ab78-7a872ad87e29",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ast\n",
"import re\n",
"\n",
"\n",
"def query_as_list(db, query):\n",
" res = db.run(query)\n",
" res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
" res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
" return res\n",
"\n",
"\n",
"proper_nouns = query_as_list(db, \"SELECT Name FROM Artist\")\n",
"proper_nouns += query_as_list(db, \"SELECT Title FROM Album\")\n",
"proper_nouns += query_as_list(db, \"SELECT Name FROM Genre\")\n",
    "proper_nouns[:5]"
]
},
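  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `db.run` returns its result as a string rather than as rows, which is why `query_as_list` round-trips it through `ast.literal_eval` before cleaning up the values. The parsing step in isolation, on a hard-coded example string:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "\n",
    "# db.run returns a string like \"[('AC/DC',), ('Accept',)]\";\n",
    "# literal_eval recovers the tuples, which we then flatten.\n",
    "raw = \"[('AC/DC',), ('Accept',)]\"\n",
    "[el for sub in ast.literal_eval(raw) for el in sub if el]"
   ]
  },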
{
"cell_type": "markdown",
"id": "22efa968-1879-4d7a-858f-7899dfa57454",
"metadata": {},
"source": [
"Now we can embed and store all of our values in a vector database:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ea50abce-545a-4dc3-8795-8d364f7d142a",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import FAISS\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())\n",
"retriever = vector_db.as_retriever(search_kwargs={\"k\": 15})"
]
},
{
"cell_type": "markdown",
"id": "a5d1d5c0-0928-40a4-b961-f1afe03cd5d3",
"metadata": {},
"source": [
"And put together a query construction chain that first retrieves values from the database and inserts them into the prompt:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "aea123ae-d809-44a0-be5d-d883c60d6a11",
"metadata": {},
"outputs": [],
"source": [
"from operator import itemgetter\n",
"\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"system = \"\"\"You are a SQLite expert. Given an input question, create a syntactically \\\n",
    "correct SQLite query to run. Unless otherwise specified, do not return more than \\\n",
"{top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nHere is a non-exhaustive \\\n",
"list of possible feature values. If filtering on a feature value make sure to check its spelling \\\n",
"against this list first:\\n\\n{proper_nouns}\"\"\"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")])\n",
"\n",
"query_chain = create_sql_query_chain(llm, db, prompt=prompt)\n",
"retriever_chain = (\n",
" itemgetter(\"question\")\n",
" | retriever\n",
" | (lambda docs: \"\\n\".join(doc.page_content for doc in docs))\n",
")\n",
"chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain"
]
},
{
"cell_type": "markdown",
"id": "12b0ed60-2536-4f82-85df-e096a272072a",
"metadata": {},
"source": [
    "To try out our chain, let's see what happens when we try filtering on \"elenis moriset\", a misspelling of Alanis Morissette, without and with retrieval:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "fcdd8432-07a4-4609-8214-b1591dd94950",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT DISTINCT Genre.Name\n",
"FROM Genre\n",
"JOIN Track ON Genre.GenreId = Track.GenreId\n",
"JOIN Album ON Track.AlbumId = Album.AlbumId\n",
"JOIN Artist ON Album.ArtistId = Artist.ArtistId\n",
"WHERE Artist.Name = 'Elenis Moriset'\n"
]
},
{
"data": {
"text/plain": [
"''"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Without retrieval\n",
"query = query_chain.invoke(\n",
" {\"question\": \"What are all the genres of elenis moriset songs\", \"proper_nouns\": \"\"}\n",
")\n",
"print(query)\n",
"db.run(query)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e8a3231a-8590-46f5-a954-da06829ee6df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SELECT DISTINCT Genre.Name\n",
"FROM Genre\n",
"JOIN Track ON Genre.GenreId = Track.GenreId\n",
"JOIN Album ON Track.AlbumId = Album.AlbumId\n",
"JOIN Artist ON Album.ArtistId = Artist.ArtistId\n",
"WHERE Artist.Name = 'Alanis Morissette'\n"
]
},
{
"data": {
"text/plain": [
"\"[('Rock',)]\""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# With retrieval\n",
"query = chain.invoke({\"question\": \"What are all the genres of elenis moriset songs\"})\n",
"print(query)\n",
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "7f99181b-a75c-4ff3-b37b-33f99a506581",
"metadata": {},
"source": [
"We can see that with retrieval we're able to correct the spelling and get back a valid result.\n",
"\n",
"Another possible approach to this problem is to let an Agent decide for itself when to look up proper nouns. You can see an example of this in the [SQL: Agents](/docs/use_cases/sql/agents) guide."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,789 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 2\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prompting strategies\n",
"\n",
"In this guide we'll go over prompting strategies to improve SQL query generation. We'll largely focus on methods for getting relevant database-specific information in your prompt.\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
"\n",
"# Uncomment the below to use LangSmith. Not required.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "The below example will use a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
    "Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\", sample_rows_in_table_info=3)\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dialect-specific prompting\n",
"\n",
"One of the simplest things we can do is make our prompt specific to the SQL dialect we're using. When using the built-in [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html) and [SQLDatabase](https://api.python.langchain.com/en/latest/utilities/langchain_community.utilities.sql_database.SQLDatabase.html), this is handled for you for any of the following dialects:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['crate',\n",
" 'duckdb',\n",
" 'googlesql',\n",
" 'mssql',\n",
" 'mysql',\n",
" 'mariadb',\n",
" 'oracle',\n",
" 'postgresql',\n",
" 'sqlite',\n",
" 'clickhouse',\n",
" 'prestodb']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains.sql_database.prompt import SQL_PROMPTS\n",
"\n",
"list(SQL_PROMPTS)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, using our current DB we can see that we'll get a SQLite-specific prompt:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"SQLQuery: SQL Query to run\n",
"SQLResult: Result of the SQLQuery\n",
"Answer: Final answer here\n",
"\n",
"Only use the following tables:\n",
"\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
"\n",
"Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
]
}
],
"source": [
"from langchain.chains import create_sql_query_chain\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"chain = create_sql_query_chain(llm, db)\n",
"chain.get_prompts()[0].pretty_print()"
]
},
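  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the dialect-specific prompt in place, the chain can be invoked directly (a quick sketch using the `chain` object just defined; the exact SQL returned may vary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create_sql_query_chain expects a \"question\" key and returns a SQL string.\n",
    "response = chain.invoke({\"question\": \"How many employees are there?\"})\n",
    "print(response)"
   ]
  },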
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Table definitions and example rows\n",
"\n",
"In basically any SQL chain, we'll need to feed the model at least part of the database schema. Without this it won't be able to write valid queries. Our database comes with some convenience methods to give us the relevant context. Specifically, we can get the table names, their schemas, and a sample of rows from each table:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['table_info', 'table_names']\n",
"\n",
"CREATE TABLE \"Album\" (\n",
"\t\"AlbumId\" INTEGER NOT NULL, \n",
"\t\"Title\" NVARCHAR(160) NOT NULL, \n",
"\t\"ArtistId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"AlbumId\"), \n",
"\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Album table:\n",
"AlbumId\tTitle\tArtistId\n",
"1\tFor Those About To Rock We Salute You\t1\n",
"2\tBalls to the Wall\t2\n",
"3\tRestless and Wild\t2\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Artist\" (\n",
"\t\"ArtistId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"ArtistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Artist table:\n",
"ArtistId\tName\n",
"1\tAC/DC\n",
"2\tAccept\n",
"3\tAerosmith\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Employee\" (\n",
"\t\"EmployeeId\" INTEGER NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Title\" NVARCHAR(30), \n",
"\t\"ReportsTo\" INTEGER, \n",
"\t\"BirthDate\" DATETIME, \n",
"\t\"HireDate\" DATETIME, \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60), \n",
"\tPRIMARY KEY (\"EmployeeId\"), \n",
"\tFOREIGN KEY(\"ReportsTo\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Employee table:\n",
"EmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\n",
"1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\tandrew@chinookcorp.com\n",
"2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\tnancy@chinookcorp.com\n",
"3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\tjane@chinookcorp.com\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Genre\" (\n",
"\t\"GenreId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"GenreId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Genre table:\n",
"GenreId\tName\n",
"1\tRock\n",
"2\tJazz\n",
"3\tMetal\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Invoice table:\n",
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"InvoiceLine\" (\n",
"\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\t\"Quantity\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceLineId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from InvoiceLine table:\n",
"InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n",
"1\t1\t2\t0.99\t1\n",
"2\t1\t4\t0.99\t1\n",
"3\t2\t6\t0.99\t1\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"MediaType\" (\n",
"\t\"MediaTypeId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"MediaTypeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from MediaType table:\n",
"MediaTypeId\tName\n",
"1\tMPEG audio file\n",
"2\tProtected AAC audio file\n",
"3\tProtected MPEG-4 video file\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Playlist\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(120), \n",
"\tPRIMARY KEY (\"PlaylistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Playlist table:\n",
"PlaylistId\tName\n",
"1\tMusic\n",
"2\tMovies\n",
"3\tTV Shows\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from PlaylistTrack table:\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Track\" (\n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(200) NOT NULL, \n",
"\t\"AlbumId\" INTEGER, \n",
"\t\"MediaTypeId\" INTEGER NOT NULL, \n",
"\t\"GenreId\" INTEGER, \n",
"\t\"Composer\" NVARCHAR(220), \n",
"\t\"Milliseconds\" INTEGER NOT NULL, \n",
"\t\"Bytes\" INTEGER, \n",
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"TrackId\"), \n",
"\tFOREIGN KEY(\"MediaTypeId\") REFERENCES \"MediaType\" (\"MediaTypeId\"), \n",
"\tFOREIGN KEY(\"GenreId\") REFERENCES \"Genre\" (\"GenreId\"), \n",
"\tFOREIGN KEY(\"AlbumId\") REFERENCES \"Album\" (\"AlbumId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Track table:\n",
"TrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n",
"1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian Johnson\t343719\t11170334\t0.99\n",
"2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n",
"3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\n",
"*/\n"
]
}
],
"source": [
"context = db.get_context()\n",
"print(list(context))\n",
"print(context[\"table_info\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the database has only a few, reasonably narrow tables, we can simply insert all of this schema information into our prompt:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"SQLQuery: SQL Query to run\n",
"SQLResult: Result of the SQLQuery\n",
"Answer: Final answer here\n",
"\n",
"Only use the following tables:\n",
"\n",
"CREATE TABLE \"Album\" (\n",
"\t\"AlbumId\" INTEGER NOT NULL, \n",
"\t\"Title\" NVARCHAR(160) NOT NULL, \n",
"\t\"ArtistId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"AlbumId\"), \n",
"\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Album table:\n",
"AlbumId\tTitle\tArtistId\n",
"1\tFor Those About To Rock We Salute You\t1\n",
"2\tBalls to the Wall\t2\n",
"3\tRestless and Wild\t2\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Artist\" (\n",
"\t\"ArtistId\" INTEGER NOT NULL, \n",
"\t\"Name\" NVARCHAR(120)\n"
]
}
],
"source": [
"prompt_with_context = chain.get_prompts()[0].partial(table_info=context[\"table_info\"])\n",
"print(prompt_with_context.pretty_repr()[:1500])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we do have database schemas that are too large to fit into our model's context window, we'll need to come up with ways of inserting only the relevant table definitions into the prompt based on the user input. For more on this head to the [Many tables, wide tables, high-cardinality feature](/docs/use_cases/sql/large_db) guide."
]
},
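{
"cell_type": "markdown",
"metadata": {},
"source": [
 "As a minimal sketch (the helper below is illustrative, not a LangChain API; it assumes `table_info` keeps the format shown above, with per-table definitions separated by blank lines), we could filter the schema string down to just the relevant tables before building the prompt:\n",
 "\n",
 "```python\n",
 "def filter_table_info(table_info: str, relevant_tables: list) -> str:\n",
 "    # Each table definition is a block separated by two blank lines;\n",
 "    # keep only blocks whose CREATE TABLE line mentions a relevant table.\n",
 "    blocks = table_info.split('\\n\\n\\n')\n",
 "    keep = [\n",
 "        b for b in blocks if any(t in b.splitlines()[0] for t in relevant_tables)\n",
 "    ]\n",
 "    return '\\n\\n\\n'.join(keep)\n",
 "```\n",
 "\n",
 "The list of relevant table names could itself come from a first model call that picks tables based on the question."
]
},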
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Few-shot examples\n",
"\n",
"Including examples of natural language questions being converted to valid SQL queries against our database in the prompt will often improve model performance, especially for complex queries.\n",
"\n",
"Let's say we have the following examples:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"examples = [\n",
" {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n",
" {\n",
" \"input\": \"Find all albums for the artist 'AC/DC'.\",\n",
" \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n",
" },\n",
" {\n",
" \"input\": \"List all tracks in the 'Rock' genre.\",\n",
" \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n",
" },\n",
" {\n",
" \"input\": \"Find the total duration of all tracks.\",\n",
" \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n",
" },\n",
" {\n",
" \"input\": \"List all customers from Canada.\",\n",
" \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n",
" },\n",
" {\n",
" \"input\": \"How many tracks are there in the album with ID 5?\",\n",
" \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n",
" },\n",
" {\n",
" \"input\": \"Find the total number of invoices.\",\n",
" \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n",
" },\n",
" {\n",
" \"input\": \"List all tracks that are longer than 5 minutes.\",\n",
" \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n",
" },\n",
" {\n",
" \"input\": \"Who are the top 5 customers by total purchase?\",\n",
" \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n",
" },\n",
" {\n",
" \"input\": \"Which albums are from the year 2000?\",\n",
" \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n",
" },\n",
" {\n",
" \"input\": \"How many employees are there\",\n",
" \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can create a few-shot prompt with them like so:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate\n",
"\n",
"example_prompt = PromptTemplate.from_template(\"User input: {input}\\nSQL query: {query}\")\n",
"prompt = FewShotPromptTemplate(\n",
" examples=examples[:5],\n",
" example_prompt=example_prompt,\n",
" prefix=\"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\",\n",
" suffix=\"User input: {input}\\nSQL query: \",\n",
" input_variables=[\"input\", \"top_k\", \"table_info\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than 3 rows.\n",
"\n",
"Here is the relevant table info: foo\n",
"\n",
"Below are a number of examples of questions and their corresponding SQL queries.\n",
"\n",
"User input: List all artists.\n",
"SQL query: SELECT * FROM Artist;\n",
"\n",
"User input: Find all albums for the artist 'AC/DC'.\n",
"SQL query: SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\n",
"\n",
"User input: List all tracks in the 'Rock' genre.\n",
"SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
"\n",
"User input: Find the total duration of all tracks.\n",
"SQL query: SELECT SUM(Milliseconds) FROM Track;\n",
"\n",
"User input: List all customers from Canada.\n",
"SQL query: SELECT * FROM Customer WHERE Country = 'Canada';\n",
"\n",
"User input: How many artists are there?\n",
"SQL query: \n"
]
}
],
"source": [
"print(prompt.format(input=\"How many artists are there?\", top_k=3, table_info=\"foo\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dynamic few-shot examples\n",
"\n",
"If we have enough examples, we may want to include only the most relevant ones in the prompt, either because they don't fit in the model's context window or because the long tail of examples distracts the model. Specifically, for any given input we want to include the examples most relevant to that input.\n",
"\n",
"We can do just this using an ExampleSelector. In this case we'll use a [SemanticSimilarityExampleSelector](https://api.python.langchain.com/en/latest/example_selectors/langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.html), which will store the examples in the vector database of our choosing. At runtime it will perform a similarity search between the input and our examples, and return the most semantically similar ones: "
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores import FAISS\n",
"from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"example_selector = SemanticSimilarityExampleSelector.from_examples(\n",
" examples,\n",
" OpenAIEmbeddings(),\n",
" FAISS,\n",
" k=5,\n",
" input_keys=[\"input\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'input': 'List all artists.', 'query': 'SELECT * FROM Artist;'},\n",
" {'input': 'How many employees are there',\n",
" 'query': 'SELECT COUNT(*) FROM \"Employee\"'},\n",
" {'input': 'How many tracks are there in the album with ID 5?',\n",
" 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'},\n",
" {'input': 'Which albums are from the year 2000?',\n",
" 'query': \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\"},\n",
" {'input': \"List all tracks in the 'Rock' genre.\",\n",
" 'query': \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\"}]"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example_selector.select_examples({\"input\": \"how many artists are there?\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use it, we can pass the ExampleSelector directly into our FewShotPromptTemplate:"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"prompt = FewShotPromptTemplate(\n",
" example_selector=example_selector,\n",
" example_prompt=example_prompt,\n",
" prefix=\"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than {top_k} rows.\\n\\nHere is the relevant table info: {table_info}\\n\\nBelow are a number of examples of questions and their corresponding SQL queries.\",\n",
" suffix=\"User input: {input}\\nSQL query: \",\n",
" input_variables=[\"input\", \"top_k\", \"table_info\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specified, do not return more than 3 rows.\n",
"\n",
"Here is the relevant table info: foo\n",
"\n",
"Below are a number of examples of questions and their corresponding SQL queries.\n",
"\n",
"User input: List all artists.\n",
"SQL query: SELECT * FROM Artist;\n",
"\n",
"User input: How many employees are there\n",
"SQL query: SELECT COUNT(*) FROM \"Employee\"\n",
"\n",
"User input: How many tracks are there in the album with ID 5?\n",
"SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n",
"\n",
"User input: Which albums are from the year 2000?\n",
"SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n",
"\n",
"User input: List all tracks in the 'Rock' genre.\n",
"SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
"\n",
"User input: how many artists are there?\n",
"SQL query: \n"
]
}
],
"source": [
"print(prompt.format(input=\"how many artists are there?\", top_k=3, table_info=\"foo\"))"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'SELECT COUNT(*) FROM Artist;'"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain = create_sql_query_chain(llm, db, prompt)\n",
"chain.invoke({\"question\": \"how many artists are there?\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -0,0 +1,389 @@
{
"cells": [
{
"cell_type": "raw",
"id": "494149c1-9a1a-4b75-8982-6bb19cc5e14e",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 3\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "4da7ae91-4973-4e97-a570-fa24024ec65d",
"metadata": {},
"source": [
"# Query validation\n",
"\n",
"Perhaps the most error-prone part of any SQL chain or agent is writing valid and safe SQL queries. In this guide we'll go over some strategies for validating our queries and handling invalid queries.\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d40d5bc-3647-4b5d-808a-db470d40fe7a",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "markdown",
"id": "c998536a-b1ff-46e7-ac51-dc6deb55d22b",
"metadata": {},
"source": [
"We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71f46270-e1c6-45b4-b36e-ea2e9f860eba",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
"\n",
"# Uncomment the below to use LangSmith. Not required.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "a0a2151b-cecf-4559-92a1-ca48824fed18",
"metadata": {},
"source": [
"The below example will use a SQLite connection with Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8cedc936-5268-4bfa-b838-bdcc1ee9573c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"id": "2d203315-fab7-4621-80da-41e9bf82d803",
"metadata": {},
"source": [
"## Query checker\n",
"\n",
"Perhaps the simplest strategy is to ask the model itself to check the original query for common mistakes. Suppose we have the following SQL query chain:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec66bb76-b1ad-48ad-a7d4-b518e9421b86",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import create_sql_query_chain\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"chain = create_sql_query_chain(llm, db)"
]
},
{
"cell_type": "markdown",
"id": "da01023d-cc05-43e3-a38d-ed9d56d3ad15",
"metadata": {},
"source": [
"And we want to validate its outputs. We can do so by extending the chain with a second prompt and model call:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "16686750-d8ee-4c60-8d67-b28281cb6164",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"system = \"\"\"Double check the user's {dialect} query for common mistakes, including:\n",
"- Using NOT IN with NULL values\n",
"- Using UNION when UNION ALL should have been used\n",
"- Using BETWEEN for exclusive ranges\n",
"- Data type mismatch in predicates\n",
"- Properly quoting identifiers\n",
"- Using the correct number of arguments for functions\n",
"- Casting to the correct data type\n",
"- Using the proper columns for joins\n",
"\n",
"If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n",
"\n",
"Output the final SQL query only.\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", system), (\"human\", \"{query}\")]\n",
").partial(dialect=db.dialect)\n",
"validation_chain = prompt | llm | StrOutputParser()\n",
"\n",
"full_chain = {\"query\": chain} | validation_chain"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "3a910260-205d-4f4e-afc6-9477572dc947",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"SELECT AVG(Invoice.Total) AS AverageInvoice\\nFROM Invoice\\nJOIN Customer ON Invoice.CustomerId = Customer.CustomerId\\nWHERE Customer.Country = 'USA'\\nAND Customer.Fax IS NULL\\nAND Invoice.InvoiceDate >= '2003-01-01'\\nAND Invoice.InvoiceDate < '2010-01-01'\""
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = full_chain.invoke(\n",
" {\n",
" \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n",
" }\n",
")\n",
"query"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "d01d78b5-89a0-4c12-b743-707ebe64ba86",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(6.632999999999998,)]'"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
{
"cell_type": "markdown",
"id": "6e133526-26bd-49da-9cfa-7adc0e59fd72",
"metadata": {},
"source": [
"The obvious downside of this approach is that we need to make two model calls instead of one to generate our query. To get around this we can try to perform the query generation and query check in a single model invocation:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "7af0030a-549e-4e69-9298-3d0a038c2fdd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m System Message \u001b[0m================================\n",
"\n",
"You are a \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m expert. Given an input question, create a syntactically correct \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query to run.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most \u001b[33;1m\u001b[1;3m{top_k}\u001b[0m results using the LIMIT clause as per \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Only use the following tables:\n",
"\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
"\n",
"Write an initial draft of the query. Then double check the \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query for common mistakes, including:\n",
"- Using NOT IN with NULL values\n",
"- Using UNION when UNION ALL should have been used\n",
"- Using BETWEEN for exclusive ranges\n",
"- Data type mismatch in predicates\n",
"- Properly quoting identifiers\n",
"- Using the correct number of arguments for functions\n",
"- Casting to the correct data type\n",
"- Using the proper columns for joins\n",
"\n",
"Use format:\n",
"\n",
"First draft: <<FIRST_DRAFT_QUERY>>\n",
"Final answer: <<FINAL_ANSWER_QUERY>>\n",
"\n",
"\n",
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"\u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
]
}
],
"source": [
"system = \"\"\"You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Only use the following tables:\n",
"{table_info}\n",
"\n",
"Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:\n",
"- Using NOT IN with NULL values\n",
"- Using UNION when UNION ALL should have been used\n",
"- Using BETWEEN for exclusive ranges\n",
"- Data type mismatch in predicates\n",
"- Properly quoting identifiers\n",
"- Using the correct number of arguments for functions\n",
"- Casting to the correct data type\n",
"- Using the proper columns for joins\n",
"\n",
"Use format:\n",
"\n",
"First draft: <<FIRST_DRAFT_QUERY>>\n",
"Final answer: <<FINAL_ANSWER_QUERY>>\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", \"{input}\")]).partial(dialect=db.dialect)\n",
"\n",
"def parse_final_answer(output: str) -> str:\n",
"    # Keep everything after the marker; fall back to the whole output\n",
"    # if the model omits the marker.\n",
"    return output.split(\"Final answer: \")[-1]\n",
" \n",
"chain = create_sql_query_chain(llm, db, prompt=prompt) | parse_final_answer\n",
"prompt.pretty_print()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "806e27a2-e511-45ea-a4ed-8ce8fa6e1d58",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\nSELECT AVG(i.Total) AS AverageInvoice\\nFROM Invoice i\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\nWHERE c.Country = 'USA' AND c.Fax IS NULL AND i.InvoiceDate >= date('2003-01-01') AND i.InvoiceDate < date('2010-01-01')\""
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = chain.invoke(\n",
" {\n",
" \"question\": \"What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010\"\n",
" }\n",
")\n",
"query"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "70fff2fa-1f86-4f83-9fd2-e87a5234d329",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(6.632999999999998,)]'"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(query)"
]
},
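{
"cell_type": "markdown",
"metadata": {},
"source": [
 "Even a checked query can still fail at execution time. As a minimal, hedged sketch (the helper name and retry logic below are illustrative, not a LangChain API), we could catch execution errors and feed the error message back to the model for another attempt:\n",
 "\n",
 "```python\n",
 "def run_query_safely(db, chain, question, max_retries=2):\n",
 "    # Generate a query, run it, and on failure ask the model to fix it.\n",
 "    query = chain.invoke({'question': question})\n",
 "    for _ in range(max_retries + 1):\n",
 "        try:\n",
 "            return db.run(query)\n",
 "        except Exception as err:\n",
 "            retry_question = (\n",
 "                question + '\\nThe last query failed with: ' + str(err) + '\\nRewrite it.'\n",
 "            )\n",
 "            query = chain.invoke({'question': retry_question})\n",
 "    raise RuntimeError('Could not produce a runnable query')\n",
 "```"
]
},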
{
"cell_type": "markdown",
"id": "fc8af115-7c23-421a-8fd7-29bf1b6687a4",
"metadata": {},
"source": [
"## Human-in-the-loop\n",
"\n",
"In some cases our data is sensitive enough that we never want to execute a SQL query without a human approving it first. Head to the [Tool use: Human-in-the-loop](/docs/use_cases/tool_use/human_in_the_loop) page to learn how to add a human-in-the-loop to any tool, chain or agent.\n",
"\n",
"## Error handling\n",
"\n",
"At some point, the model will make a mistake and craft an invalid SQL query. Or an issue will arise with our database. Or the model API will go down. We'll want to add some error handling behavior to our chains and agents so that we fail gracefully in these situations, and perhaps even automatically recover. To learn about error handling with tools, head to the [Tool use: Error handling](/docs/use_cases/tool_use/tool_error_handling) page."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,603 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_position: 0\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quickstart\n",
"\n",
"In this guide we'll go over the basic ways to create a Q&A chain and agent over a SQL database. These systems will allow us to ask a question about the data in a SQL database and get back a natural language answer. The main difference between the two is that our agent can query the database in a loop as many times as it needs to answer the question.\n",
"\n",
"## ⚠️ Security note ⚠️\n",
"\n",
"Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
"\n",
"\n",
"## Architecture\n",
"\n",
"At a high-level, the steps of any SQL chain and agent are:\n",
"\n",
"1. **Convert question to SQL query**: Model converts user input to a SQL query.\n",
"2. **Execute SQL query**: Execute the SQL query.\n",
"3. **Answer the question**: Model responds to user input using the query results.\n",
"\n",
"\n",
"![sql_usecase.png](../../../static/img/sql_usecase.png)\n",
"\n",
"## Setup\n",
"\n",
"First, get required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We default to OpenAI models in this guide."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
"\n",
"# Uncomment the below to use LangSmith. Not required.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The below example will use a SQLite connection with Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
"\n",
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`\n",
"* Run `sqlite3 Chinook.db`\n",
"* Run `.read Chinook_Sqlite.sql`\n",
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
"\n",
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.\n",
"\n",
"## Chain\n",
"\n",
"Let's create a simple chain that takes a question, turns it into a SQL query, executes the query, and uses the result to answer the original question.\n",
"\n",
"### Convert question to SQL query\n",
"\n",
"The first step in a SQL chain or agent is to take the user input and convert it to a SQL query. LangChain comes with a built-in chain for this: [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'SELECT COUNT(*) FROM Employee'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chains import create_sql_query_chain\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"chain = create_sql_query_chain(llm, db)\n",
"response = chain.invoke({\"question\": \"How many employees are there\"})\n",
"response"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can execute the query to make sure it's valid:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(8,)]'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.run(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can look at the [LangSmith trace](https://smith.langchain.com/public/c8fa52ea-be46-4829-bde2-52894970b830/r) to get a better understanding of what this chain is doing. We can also inspect the chain directly for its prompts. Looking at the prompt (below), we can see that it:\n",
"\n",
"* Is dialect-specific. In this case it references SQLite explicitly.\n",
"* Has definitions for all the available tables.\n",
"* Has three example rows for each table.\n",
"\n",
"This technique is inspired by papers like [this one](https://arxiv.org/pdf/2204.00498.pdf), which suggest that showing example rows and being explicit about tables improves performance."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n",
"Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n",
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
"\n",
"Use the following format:\n",
"\n",
"Question: Question here\n",
"SQLQuery: SQL Query to run\n",
"SQLResult: Result of the SQLQuery\n",
"Answer: Final answer here\n",
"\n",
"Only use the following tables:\n",
"\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
"\n",
"Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
]
}
],
"source": [
"chain.get_prompts()[0].pretty_print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Execute SQL query\n",
"\n",
"Now that we've generated a SQL query, we'll want to execute it. **This is the most dangerous part of creating a SQL chain.** Consider carefully if it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chains before query execution (see below).\n",
"\n",
"We can use the `QuerySQLDataBaseTool` to easily add query execution to our chain:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[(8,)]'"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n",
"\n",
"execute_query = QuerySQLDataBaseTool(db=db)\n",
"write_query = create_sql_query_chain(llm, db)\n",
"chain = write_query | execute_query\n",
"chain.invoke({\"question\": \"How many employees are there\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Answer the question\n",
"\n",
"Now that we've got a way to automatically generate and execute queries, we just need to combine the original question and SQL query result to generate a final answer. We can do this by passing question and result to the LLM once more:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'There are 8 employees.'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from operator import itemgetter\n",
"\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"answer_prompt = PromptTemplate.from_template(\n",
" \"\"\"Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n",
"\n",
"Question: {question}\n",
"SQL Query: {query}\n",
"SQL Result: {result}\n",
"Answer: \"\"\"\n",
")\n",
"\n",
"answer = answer_prompt | llm | StrOutputParser()\n",
"chain = (\n",
" RunnablePassthrough.assign(query=write_query).assign(\n",
" result=itemgetter(\"query\") | execute_query\n",
" )\n",
" | answer\n",
")\n",
"\n",
"chain.invoke({\"question\": \"How many employees are there\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Next steps\n",
"\n",
"For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For these and other advanced techniques, check out:\n",
"\n",
"* [Prompting strategies](/docs/use_cases/sql/prompting): Advanced prompt engineering techniques.\n",
"* [Query checking](/docs/use_cases/sql/query_checking): Add query validation and error handling.\n",
"* [Large databases](/docs/use_cases/sql/large_db): Techniques for working with large databases."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agents\n",
"\n",
"LangChain has an SQL Agent which provides a more flexible way of interacting with SQL databases. The main advantages of using the SQL Agent are:\n",
"\n",
"- It can answer questions based on the database's schema as well as on the database's content (like describing a specific table).\n",
"- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n",
"- It can answer questions that require multiple dependent queries.\n",
"\n",
"To initialize the agent, we use the `create_sql_agent` function. This agent uses the `SQLDatabaseToolkit`, which contains tools to:\n",
"\n",
"* Create and execute queries\n",
"* Check query syntax\n",
"* Retrieve table descriptions\n",
"* ... and more\n",
"\n",
"### Initializing agent"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.agent_toolkits import create_sql_agent\n",
"\n",
"agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_schema` with `Invoice,Customer`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Invoice table:\n",
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n",
"responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n",
"\n",
"Here is the SQL query:\n",
"\n",
"```sql\n",
"SELECT c.Country, SUM(i.Total) AS TotalSales\n",
"FROM Invoice i\n",
"JOIN Customer c ON i.CustomerId = c.CustomerId\n",
"GROUP BY c.Country\n",
"ORDER BY TotalSales DESC\n",
"LIMIT 10;\n",
"```\n",
"\n",
"Now, I will execute this query to get the total sales per country.\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
"\n",
"1. USA: $523.06\n",
"2. Canada: $303.96\n",
"3. France: $195.10\n",
"4. Brazil: $190.10\n",
"5. Germany: $156.48\n",
"6. United Kingdom: $112.86\n",
"7. Czech Republic: $90.24\n",
"8. Portugal: $77.24\n",
"9. India: $75.26\n",
"10. Chile: $46.62\n",
"\n",
"To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
" 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" {\n",
" \"input\": \"List the total sales per country. Which country's customers spent the most?\"\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_list_tables` with `{}`\n",
"\n",
"\n",
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `sql_db_schema` with `PlaylistTrack`\n",
"\n",
"\n",
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from PlaylistTrack table:\n",
"PlaylistId\tTrackId\n",
"1\t3402\n",
"1\t3389\n",
"1\t3390\n",
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n",
"\n",
"Here is the schema of the `PlaylistTrack` table:\n",
"\n",
"```\n",
"CREATE TABLE \"PlaylistTrack\" (\n",
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
"\t\"TrackId\" INTEGER NOT NULL, \n",
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
")\n",
"```\n",
"\n",
"The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n",
"\n",
"Here are three sample rows from the `PlaylistTrack` table:\n",
"\n",
"```\n",
"PlaylistId TrackId\n",
"1 3402\n",
"1 3389\n",
"1 3390\n",
"```\n",
"\n",
"Please let me know if there is anything else I can help with.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': 'Describe the playlisttrack table',\n",
" 'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \\n\\nHere is the schema of the `PlaylistTrack` table:\\n\\n```\\nCREATE TABLE \"PlaylistTrack\" (\\n\\t\"PlaylistId\" INTEGER NOT NULL, \\n\\t\"TrackId\" INTEGER NOT NULL, \\n\\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \\n\\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \\n\\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\\n)\\n```\\n\\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\\n\\nHere are three sample rows from the `PlaylistTrack` table:\\n\\n```\\nPlaylistId TrackId\\n1 3402\\n1 3389\\n1 3390\\n```\\n\\nPlease let me know if there is anything else I can help with.'}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke({\"input\": \"Describe the playlisttrack table\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Next steps\n",
"\n",
"For more on how to use and customize agents head to the [Agents](/docs/use_cases/sql/agents) page."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "poetry-venv",
"language": "python",
"name": "poetry-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
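The safety guidance in the notebook above (minimal database permissions, an approval or validation gate before executing model-generated SQL) can be sketched with the standard library alone. This is a minimal sketch, assuming an in-memory Chinook-like `Employee` table; the allow-list, sample rows, and helper name are illustrative, not part of the notebook:

```python
import sqlite3

# Hypothetical allow-list: only read-only SELECT statements may run.
ALLOWED_PREFIXES = ("select",)

def approve_and_run(conn, query):
    """Reject anything that is not a plain SELECT before executing."""
    if not query.strip().lower().startswith(ALLOWED_PREFIXES):
        raise PermissionError(f"Refusing to run non-SELECT query: {query!r}")
    return conn.execute(query).fetchall()

# Build a tiny stand-in for the Chinook Employee table (8 rows, as in the notebook).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [(name,) for name in ["Adams", "Edwards", "Peacock", "Park",
                          "Johnson", "Mitchell", "King", "Callahan"]],
)

print(approve_and_run(conn, "SELECT COUNT(*) FROM Employee"))  # [(8,)]
```

A real deployment would pair a gate like this with a read-only database role, since prefix checks alone are easy to bypass.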

View File

@ -1,5 +1,9 @@
{
"redirects": [
{
"source": "/docs/use_cases/qa_structured/sql",
"destination": "/docs/use_cases/sql/"
},
{
"source": "/docs/contributing/packages",
"destination": "/docs/packages"
@ -294,7 +298,7 @@
},
{
"source": "/docs/use_cases/qa_structured/integrations/sqlite",
"destination": "/cookbook"
"destination": "/docs/use_cases/sql/"
},
{
"source": "/docs/use_cases/more/graph/tot",
@ -1418,7 +1422,7 @@
},
{
"source": "/docs/integrations/tools/sqlite",
"destination": "/docs/use_cases/qa_structured/sql"
"destination": "/docs/use_cases/sql"
},
{
"source": "/en/latest/modules/callbacks/filecallbackhandler.html",
@ -3506,11 +3510,7 @@
},
{
"source": "/en/latest/use_cases/tabular.html",
"destination": "/docs/use_cases/qa_structured"
},
{
"source": "/docs/use_cases/sql(/?)",
"destination": "/docs/use_cases/qa_structured/sql"
"destination": "/docs/use_cases/sql"
},
{
"source": "/en/latest/youtube.html",

View File

@ -1,11 +1,11 @@
"""SQL agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
@ -15,22 +15,29 @@ from langchain_core.prompts.chat import (
from langchain_community.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools import BaseTool
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
)
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_community.utilities.sql_database import SQLDatabase
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: Optional[AgentType] = None,
toolkit: Optional[SQLDatabaseToolkit] = None,
agent_type: Optional[Union[AgentType, Literal["openai-tools"]]] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
@ -41,62 +48,165 @@ def create_sql_agent(
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (),
*,
db: Optional[SQLDatabase] = None,
prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.chains.llm import LLMChain
"""Construct a SQL agent from an LLM and toolkit or database.
Args:
llm: Language model to use for the agent.
toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
for the agent and the toolkit.
agent_type: One of "openai-tools", "openai-functions", or
"zero-shot-react-description". Defaults to "zero-shot-react-description".
"openai-tools" is recommended over "openai-functions".
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
instead to pass constructor callbacks to AgentExecutor.
prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
suffix: Prompt suffix string. Default depends on agent type.
format_instructions: Formatting instructions to pass to
ZeroShotAgent.create_prompt() when 'agent_type' is
"zero-shot-react-description". Otherwise ignored.
input_variables: DEPRECATED. Input variables to explicitly specify as part of
ZeroShotAgent.create_prompt() when 'agent_type' is
"zero-shot-react-description". Otherwise ignored.
top_k: Number of rows to query for by default.
max_iterations: Passed to AgentExecutor init.
max_execution_time: Passed to AgentExecutor init.
early_stopping_method: Passed to AgentExecutor init.
verbose: AgentExecutor verbosity.
agent_executor_kwargs: Arbitrary additional AgentExecutor args.
extra_tools: Additional tools to give to agent on top of the ones that come with
SQLDatabaseToolkit.
db: SQLDatabase from which to create a SQLDatabaseToolkit. Toolkit is created
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
input_variables} are mutually exclusive.
**kwargs: DEPRECATED. Not used, kept for backwards compatibility.
Returns:
An AgentExecutor with the specified agent_type agent.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
""" # noqa: E501
from langchain.agents import (
create_openai_functions_agent,
create_openai_tools_agent,
create_react_agent,
)
from langchain.agents.agent import (
AgentExecutor,
RunnableAgent,
RunnableMultiActionAgent,
)
from langchain.agents.agent_types import AgentType
if toolkit is None and db is None:
raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received neither."
)
if toolkit and db:
raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received both."
)
if kwargs:
warnings.warn(
f"Received additional kwargs {kwargs} which are no longer supported."
)
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db)
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
tools = toolkit.get_tools() + list(extra_tools)
if prompt is None:
prefix = prefix or SQL_PREFIX
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent
else:
if "top_k" in prompt.input_variables:
prompt = prompt.partial(top_k=str(top_k))
if "dialect" in prompt.input_variables:
prompt = prompt.partial(dialect=toolkit.dialect)
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt_params = (
{"format_instructions": format_instructions}
if format_instructions is not None
else {}
if prompt is None:
from langchain.agents.mrkl import prompt as react_prompt
format_instructions = (
format_instructions or react_prompt.FORMAT_INSTRUCTIONS
)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
input_variables=input_variables,
**prompt_params,
template = "\n\n".join(
[
react_prompt.PREFIX,
"{tools}",
format_instructions,
react_prompt.SUFFIX,
]
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
prompt = PromptTemplate.from_template(template)
agent = RunnableAgent(
runnable=create_react_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
if prompt is None:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
)
elif agent_type == "openai-tools":
if prompt is None:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableMultiActionAgent(
runnable=create_openai_tools_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
return AgentExecutor(
name="SQL Agent Executor",
agent=agent,
tools=tools,
callback_manager=callback_manager,
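The new mutually exclusive `toolkit`/`db` handling shown in the diff above can be isolated as a plain-Python sketch; the helper name and the fallback toolkit factory below are hypothetical, not LangChain APIs:

```python
def resolve_toolkit(toolkit=None, db=None, make_toolkit=None):
    """Enforce that exactly one of 'toolkit' or 'db' is provided,
    mirroring the validation added in create_sql_agent (sketch only)."""
    if toolkit is None and db is None:
        raise ValueError("Must provide exactly one of 'toolkit' or 'db'. Received neither.")
    if toolkit is not None and db is not None:
        raise ValueError("Must provide exactly one of 'toolkit' or 'db'. Received both.")
    # When only a db was given, build a toolkit from it (stand-in factory).
    factory = make_toolkit or (lambda d: ("toolkit-for", d))
    return toolkit if toolkit is not None else factory(db)

print(resolve_toolkit(db="sqlite:///Chinook.db"))  # ('toolkit-for', 'sqlite:///Chinook.db')
```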

View File

@ -69,3 +69,7 @@ class SQLDatabaseToolkit(BaseToolkit):
list_sql_database_tool,
query_sql_checker_tool,
]
def get_context(self) -> dict:
"""Return db context that you may want in agent prompt."""
return self.db.get_context()

View File

@ -1,10 +1,10 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import warnings
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.engine import Engine
@ -272,11 +272,9 @@ class SQLDatabase:
return sorted(self._include_tables)
return sorted(self._all_tables - self._ignore_tables)
@deprecated("0.0.1", alternative="get_usable_table_name", removal="0.2.0")
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
warnings.warn(
"This method is deprecated - please use `get_usable_table_names`."
)
return self.get_usable_table_names()
@property
@ -487,3 +485,9 @@ class SQLDatabase:
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
def get_context(self) -> Dict[str, Any]:
"""Return db context that you may want in agent prompt."""
table_names = list(self.get_usable_table_names())
table_info = self.get_table_info_no_throw()
return {"table_info": table_info, "table_names": ", ".join(table_names)}
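The new `get_context` method above simply packages table metadata for use in agent prompts; its contract can be sketched independently of SQLAlchemy (standalone function, illustrative inputs):

```python
def get_context(table_names, table_info):
    """Sketch of SQLDatabase.get_context() from the diff above:
    bundle table DDL plus a comma-separated table-name list."""
    return {"table_info": table_info, "table_names": ", ".join(table_names)}

print(get_context(["Album", "Artist"], "CREATE TABLE Album (...)"))
# {'table_info': 'CREATE TABLE Album (...)', 'table_names': 'Album, Artist'}
```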

View File

@ -74,6 +74,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
vectorstore_cls: Type[VectorStore],
k: int = 4,
input_keys: Optional[List[str]] = None,
*,
example_keys: Optional[List[str]] = None,
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> SemanticSimilarityExampleSelector:
"""Create k-shot example selector using example list and embeddings.
@ -102,7 +105,13 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
vectorstore = vectorstore_cls.from_texts(
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
)
return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)
return cls(
vectorstore=vectorstore,
k=k,
input_keys=input_keys,
example_keys=example_keys,
vectorstore_kwargs=vectorstore_kwargs,
)
class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):

View File

@ -343,8 +343,8 @@ class RunnableAgent(BaseSingleActionAgent):
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action."""
_input_keys: List[str] = []
"""Input keys."""
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
class Config:
"""Configuration for this pydantic object."""
@ -354,16 +354,11 @@ class RunnableAgent(BaseSingleActionAgent):
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return []
return self.return_keys_arg
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
Returns:
List of input keys.
"""
return self._input_keys
return self.input_keys_arg
def plan(
self,
@ -439,8 +434,8 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions."""
_input_keys: List[str] = []
"""Input keys."""
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
class Config:
"""Configuration for this pydantic object."""
@ -450,7 +445,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
@property
def return_values(self) -> List[str]:
"""Return values of the agent."""
return []
return self.return_keys_arg
@property
def input_keys(self) -> List[str]:
@ -459,7 +454,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
Returns:
List of input keys.
"""
return self._input_keys
return self.input_keys_arg
def plan(
self,

View File

@ -83,9 +83,9 @@ class ZeroShotAgent(Agent):
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
if input_variables:
return PromptTemplate(template=template, input_variables=input_variables)
return PromptTemplate.from_template(template)
@classmethod
def from_llm_and_tools(

View File

@ -131,7 +131,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)
@ -178,7 +178,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks = config.get("callbacks")
tags = config.get("tags")
metadata = config.get("metadata")
run_name = config.get("run_name")
run_name = config.get("run_name") or self.get_name()
include_run_info = kwargs.get("include_run_info", False)
return_only_outputs = kwargs.get("return_only_outputs", False)

View File

@ -1,10 +1,10 @@
from typing import List, Optional, TypedDict, Union
from typing import Any, Dict, List, Optional, TypedDict, Union
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnableParallel
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
@ -31,7 +31,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
"""Create a chain that generates SQL queries.
*Security Note*: This chain generates SQL queries for the given database.
@ -50,34 +50,93 @@ def create_sql_query_chain(
See https://python.langchain.com/docs/security for more information.
Args:
llm: The language model to use
db: The SQLDatabase to generate the query for
llm: The language model to use.
db: The SQLDatabase to generate the query for.
prompt: The prompt to use. If none is provided, will choose one
based on dialect. Defaults to None.
based on dialect. Defaults to None. See Prompt section below for more.
k: The number of results per select statement to return. Defaults to 5.
Returns:
A chain that takes in a question and generates a SQL query that answers
that question.
"""
Example:
.. code-block:: python
# pip install -U langchain langchain-community langchain-openai
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
Prompt:
If no prompt is provided, a default prompt is selected based on the SQLDatabase dialect. If one is provided, it must support input variables:
* input: The user question plus suffix "\nSQLQuery: " is passed here.
* top_k: The number of results per select statement (the `k` argument to
this function) is passed in here.
* table_info: Table definitions and sample rows are passed in here. If the
user specifies "table_names_to_use" when invoking chain, only those
will be included. Otherwise, all tables are included.
* dialect (optional): If dialect input variable is in prompt, the db
dialect will be passed in here.
Here's an example prompt:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}.
Question: {input}'''
prompt = PromptTemplate.from_template(template)
""" # noqa: E501
if prompt is not None:
prompt_to_use = prompt
elif db.dialect in SQL_PROMPTS:
prompt_to_use = SQL_PROMPTS[db.dialect]
else:
prompt_to_use = PROMPT
if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables):
raise ValueError(
f"Prompt must have input variables: 'input', 'top_k', "
f"'table_info'. Received prompt with input variables: "
f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
)
if "dialect" in prompt_to_use.input_variables:
prompt_to_use = prompt_to_use.partial(dialect=db.dialect)
inputs = {
"input": lambda x: x["question"] + "\nSQLQuery: ",
"top_k": lambda _: k,
"table_info": lambda x: db.get_table_info(
table_names=x.get("table_names_to_use")
),
}
if "dialect" in prompt_to_use.input_variables:
inputs["dialect"] = lambda _: (db.dialect, prompt_to_use)
return (
RunnableParallel(inputs)
| prompt_to_use
RunnablePassthrough.assign(**inputs) # type: ignore
| (
lambda x: {
k: v
for k, v in x.items()
if k not in ("question", "table_names_to_use")
}
)
| prompt_to_use.partial(top_k=str(k))
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
| _strip
)
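The reworked pipeline merges computed keys into the caller's input dict, then strips the caller-only keys before the prompt is formatted. A minimal stdlib sketch of that data flow, with plain functions standing in for `RunnablePassthrough.assign` and the filtering lambda (all names illustrative):

```python
def assign(inputs, **fns):
    # Mimic RunnablePassthrough.assign: compute new keys from the full
    # input dict and merge them over a copy of it.
    merged = dict(inputs)
    for key, fn in fns.items():
        merged[key] = fn(inputs)
    return merged

def drop_caller_keys(d):
    # Mimic the filtering lambda: remove keys the prompt doesn't accept.
    return {k: v for k, v in d.items() if k not in ("question", "table_names_to_use")}

raw = {"question": "How many employees are there", "table_names_to_use": ["employees"]}
prepared = drop_caller_keys(
    assign(
        raw,
        input=lambda x: x["question"] + "\nSQLQuery: ",
        table_info=lambda x: "CREATE TABLE employees (...)",  # db.get_table_info stand-in
    )
)
print(sorted(prepared))  # → ['input', 'table_info']
```

The real chain then pipes this dict into `prompt_to_use.partial(top_k=str(k))`, which is why `top_k` no longer needs to appear in the invocation input.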
View File

@ -1,3 +1,7 @@
from functools import partial
from typing import Optional
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
@ -10,8 +14,37 @@ class RetrieverInput(BaseModel):
query: str = Field(description="query to look up in retriever")
def _get_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
) -> str:
docs = retriever.get_relevant_documents(query)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
async def _aget_relevant_documents(
query: str,
retriever: BaseRetriever,
document_prompt: BasePromptTemplate,
document_separator: str,
) -> str:
docs = await retriever.aget_relevant_documents(query)
return document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
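Splitting the sync and async paths into two module-level helpers lets both be bound to the same configuration with `functools.partial`, which is what the `create_retriever_tool` change below does. A small sketch of the pattern, with a dict lookup standing in for a retriever (names illustrative):

```python
import asyncio
from functools import partial

def fetch(query, source, separator):
    # Sync path: look up documents for the query and join them.
    return separator.join(source.get(query, []))

async def afetch(query, source, separator):
    # Async path mirrors the sync one; sleep(0) stands in for an awaited call.
    await asyncio.sleep(0)
    return fetch(query, source, separator)

source = {"cats": ["doc one", "doc two"]}
sync_fn = partial(fetch, source=source, separator="\n\n")
async_fn = partial(afetch, source=source, separator="\n\n")

print(sync_fn("cats"))
print(asyncio.run(async_fn("cats")))  # same result via the async path
```

Binding the shared arguments once keeps the two callables in lockstep, so the `Tool`'s `func` and `coroutine` cannot drift apart in configuration.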
def create_retriever_tool(
retriever: BaseRetriever, name: str, description: str
retriever: BaseRetriever,
name: str,
description: str,
*,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
) -> Tool:
"""Create a tool to do retrieval of documents.
@ -25,10 +58,23 @@ def create_retriever_tool(
Returns:
Tool class to pass to an agent
"""
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
func = partial(
_get_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
afunc = partial(
_aget_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
)
return Tool(
name=name,
description=description,
func=retriever.get_relevant_documents,
coroutine=retriever.aget_relevant_documents,
func=func,
coroutine=afunc,
args_schema=RetrieverInput,
)
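The new `document_prompt` and `document_separator` parameters control how retrieved documents are rendered into the tool's single string output. A sketch of that formatting step, using a trivial stand-in for `format_document` (the real one lives in `langchain_core`; the dict documents and template here are illustrative):

```python
def format_document(doc, template):
    # Stand-in for langchain_core's format_document: fill the template
    # from the document's fields.
    return template.format(**doc)

docs = [
    {"page_content": "LangChain is a framework.", "source": "intro.md"},
    {"page_content": "Tools wrap callables.", "source": "tools.md"},
]

# Equivalent of passing a custom document_prompt and the default separator.
template = "Source: {source}\n{page_content}"
separator = "\n\n"
rendered = separator.join(format_document(d, template) for d in docs)
print(rendered)
```

With the defaults (`"{page_content}"` and `"\n\n"`), the tool's output is just the page contents joined by blank lines; a custom template lets the agent see metadata such as the source alongside each document.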