mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
1172 lines
44 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Build a Question/Answering system over SQL data\n",
|
|
"\n",
|
|
"Enabling an LLM system to query structured data can be qualitatively different from querying unstructured text data. Whereas in the latter it is common to generate text that can be searched against a vector database, the approach for structured data is often for the LLM to write and execute queries in a DSL, such as SQL. In this guide we'll go over the basic ways to create a Q&A system over tabular data in databases. We will cover implementations using both chains and agents. These systems will allow us to ask a question about the data in a database and get back a natural language answer. The main difference between the two is that our agent can query the database in a loop as many times as it needs to answer the question.\n",
|
|
"\n",
|
|
"## ⚠️ Security note ⚠️\n",
|
|
"\n",
|
|
"Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
|
|
"\n",
|
|
"\n",
|
|
"## Architecture\n",
|
|
"\n",
|
|
"At a high level, the steps of these systems are:\n",
|
|
"\n",
|
|
"1. **Convert question to DSL query**: Model converts user input to a SQL query.\n",
|
|
"2. **Execute SQL query**: Execute the query.\n",
|
|
"3. **Answer the question**: Model responds to user input using the query results.\n",
|
|
"\n",
|
|
"Note that querying data in CSVs can follow a similar approach. See our [how-to guide](/docs/use_cases/sql/csv/) on question-answering over CSV data for more detail.\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"## Setup\n",
|
|
"\n",
|
|
"First, get required packages and set environment variables:"
|
|
]
|
|
},
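{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three steps above can be sketched with nothing but the Python standard library. This is a minimal illustration under stated assumptions, not how LangChain implements them: `question_to_sql` and `answer_from_rows` are hypothetical stand-ins for the two model calls (they return fixed text), and the database is a throwaway SQLite file. Reopening that file with `mode=ro` is one concrete way to scope connection permissions narrowly, as the security note recommends: reads work, writes fail."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"import sqlite3\n",
"import tempfile\n",
"\n",
"# Build a tiny throwaway database file so the example is self-contained.\n",
"path = pathlib.Path(tempfile.mkdtemp()) / 'demo.db'\n",
"setup = sqlite3.connect(path)\n",
"setup.execute('CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, Name TEXT)')\n",
"setup.executemany('INSERT INTO Employee (Name) VALUES (?)', [('Ann',), ('Bob',), ('Cy',)])\n",
"setup.commit()\n",
"setup.close()\n",
"\n",
"# Reopen the same file read-only: reads work, writes through this connection fail.\n",
"conn = sqlite3.connect(path.as_uri() + '?mode=ro', uri=True)\n",
"\n",
"def question_to_sql(question):\n",
"    # Hypothetical stand-in for step 1; a real system asks an LLM to write the SQL.\n",
"    return 'SELECT COUNT(*) FROM Employee'\n",
"\n",
"def answer_from_rows(question, rows):\n",
"    # Hypothetical stand-in for step 3; a real system asks an LLM to phrase the answer.\n",
"    return f'There are {rows[0][0]} employees.'\n",
"\n",
"question = 'How many employees are there?'\n",
"sql = question_to_sql(question)          # 1. convert question to SQL\n",
"rows = conn.execute(sql).fetchall()      # 2. execute the SQL query\n",
"print(answer_from_rows(question, rows))  # 3. answer using the results\n",
"\n",
"try:\n",
"    conn.execute('INSERT INTO Employee (Name) VALUES (?)', ('Eve',))\n",
"except sqlite3.OperationalError as err:\n",
"    print('write rejected:', err)"
]
},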
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%pip install --upgrade --quiet langchain langchain-community langchain-openai"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We will use an OpenAI model in this guide."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import getpass\n",
|
|
"import os\n",
|
|
"\n",
|
|
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
|
|
"\n",
|
|
"# Uncomment the below to use LangSmith. Not required.\n",
|
|
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()\n",
|
|
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The example below uses a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:\n",
|
|
"\n",
|
|
"* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook.sql`\n",
|
|
"* Run `sqlite3 Chinook.db`\n",
|
|
"* Run `.read Chinook.sql`\n",
|
|
"* Test `SELECT * FROM Artist LIMIT 10;`\n",
|
|
"\n",
|
|
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"sqlite\n",
|
|
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
|
|
]
|
|
},
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain_community.utilities import SQLDatabase\n",
|
|
"\n",
|
|
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
|
|
"print(db.dialect)\n",
|
|
"print(db.get_usable_table_names())\n",
|
|
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
|
|
]
|
|
},
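{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what `get_usable_table_names()` and `run()` return, here is a rough equivalent written with the standard `sqlite3` module against a tiny in-memory stand-in for the Chinook `Artist` table. It is an approximation for illustration only: `table_names` and `run_as_string` are hypothetical helpers, and `SQLDatabase` itself works through SQLAlchemy, so it also supports dialects other than SQLite."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"# Tiny in-memory stand-in for Chinook's Artist table.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)')\n",
"conn.executemany('INSERT INTO Artist (Name) VALUES (?)', [('AC/DC',), ('Accept',), ('Aerosmith',)])\n",
"\n",
"def table_names(conn):\n",
"    # SQLite lists every table in its sqlite_master catalog; skip internal tables.\n",
"    rows = conn.execute(\n",
"        \"SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name\"\n",
"    ).fetchall()\n",
"    return [name for (name,) in rows]\n",
"\n",
"def run_as_string(conn, query):\n",
"    # Return the rows as the string form of a list of tuples, the shape db.run printed above.\n",
"    return str(conn.execute(query).fetchall())\n",
"\n",
"print(table_names(conn))\n",
"print(run_as_string(conn, 'SELECT * FROM Artist LIMIT 2;'))"
]
},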
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.\n",
|
|
"\n",
|
|
"## Chains {#chains}\n",
|
|
"\n",
|
|
"Chains (i.e., compositions of LangChain [Runnables](/docs/expression_language/)) support applications whose steps are predictable. We can create a simple chain that takes a question and does the following:\n",
|
|
"- convert the question into a SQL query;\n",
|
|
"- execute the query;\n",
|
|
"- use the result to answer the original question.\n",
|
|
"\n",
|
|
"There are scenarios not supported by this arrangement. For example, this system will execute a SQL query for any user input, even \"hello\". Importantly, as we'll see below, some questions require more than one query to answer. We will address these scenarios in the Agents section.\n",
|
|
"\n",
|
|
"### Convert question to SQL query\n",
|
|
"\n",
|
|
"The first step in a SQL chain or agent is to take the user input and convert it to a SQL query. LangChain comes with a built-in chain for this: [create_sql_query_chain](https://api.python.langchain.com/en/latest/chains/langchain.chains.sql_database.query.create_sql_query_chain.html)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"```{=mdx}\n",
|
|
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
|
|
"\n",
|
|
"<ChatModelTabs customVarName=\"llm\" />\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# | output: false\n",
|
|
"# | echo: false\n",
|
|
"\n",
|
|
"from langchain_openai import ChatOpenAI\n",
|
|
"\n",
|
|
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'SELECT COUNT(\"EmployeeId\") AS \"TotalEmployees\" FROM \"Employee\"\\nLIMIT 1;'"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain.chains import create_sql_query_chain\n",
|
|
"\n",
|
|
"chain = create_sql_query_chain(llm, db)\n",
|
|
"response = chain.invoke({\"question\": \"How many employees are there\"})\n",
|
|
"response"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can execute the query to make sure it's valid:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'[(8,)]'"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"db.run(response)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can look at the [LangSmith trace](https://smith.langchain.com/public/c8fa52ea-be46-4829-bde2-52894970b830/r) to get a better understanding of what this chain is doing. We can also inspect the chain directly for its prompts. Looking at the prompt (below), we can see that it:\n",
"\n",
"* Is dialect-specific. In this case it references SQLite explicitly.\n",
"* Has definitions for all the available tables.\n",
"* Has three example rows for each table.\n",
|
|
"\n",
|
|
"This technique is inspired by papers like [this](https://arxiv.org/pdf/2204.00498.pdf), which suggest that showing example rows and being explicit about tables improves performance. The full prompt can be printed like so:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\n",
|
|
"Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\n",
|
|
"Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (\") to denote them as delimited identifiers.\n",
|
|
"Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n",
|
|
"Pay attention to use date('now') function to get the current date, if the question involves \"today\".\n",
|
|
"\n",
|
|
"Use the following format:\n",
|
|
"\n",
|
|
"Question: Question here\n",
|
|
"SQLQuery: SQL Query to run\n",
|
|
"SQLResult: Result of the SQLQuery\n",
|
|
"Answer: Final answer here\n",
|
|
"\n",
|
|
"Only use the following tables:\n",
|
|
"\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
|
|
"\n",
|
|
"Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"chain.get_prompts()[0].pretty_print()"
|
|
]
|
|
},
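{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `{table_info}` slot in this prompt is where the table definitions and sample rows go; the agent trace later in this guide shows that layout (a `CREATE TABLE` statement followed by a `/* 3 rows from ... */` comment). A rough standard-library approximation for SQLite, using a hypothetical `describe_table` helper, is below. LangChain's own formatting may differ in details, so treat this as an illustration of the idea only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT)')\n",
"conn.executemany('INSERT INTO Genre (Name) VALUES (?)', [('Rock',), ('Jazz',), ('Metal',), ('Blues',)])\n",
"\n",
"def describe_table(conn, table, n_rows=3):\n",
"    # SQLite keeps each table's CREATE TABLE statement in sqlite_master.\n",
"    (ddl,) = conn.execute(\n",
"        'SELECT sql FROM sqlite_master WHERE type = ? AND name = ?', ('table', table)\n",
"    ).fetchone()\n",
"    # `table` is interpolated, so only pass names that came from the schema itself.\n",
"    cur = conn.execute(f'SELECT * FROM \"{table}\" LIMIT {int(n_rows)}')\n",
"    header = '\\t'.join(col[0] for col in cur.description)\n",
"    rows = ['\\t'.join(str(value) for value in row) for row in cur.fetchall()]\n",
"    lines = [ddl, '', '/*', f'{n_rows} rows from {table} table:', header, *rows, '*/']\n",
"    return '\\n'.join(lines)\n",
"\n",
"print(describe_table(conn, 'Genre'))"
]
},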
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Execute SQL query\n",
|
|
"\n",
|
|
"Now that we've generated a SQL query, we'll want to execute it. **This is the most dangerous part of creating a SQL chain.** Consider carefully whether it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chains before query execution (see below).\n",
|
|
"\n",
|
|
"We can use the `QuerySQLDataBaseTool` to easily add query execution to our chain:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'[(8,)]'"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n",
|
|
"\n",
|
|
"execute_query = QuerySQLDataBaseTool(db=db)\n",
|
|
"write_query = create_sql_query_chain(llm, db)\n",
|
|
"chain = write_query | execute_query\n",
|
|
"chain.invoke({\"question\": \"How many employees are there\"})"
|
|
]
|
|
},
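{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible shape for a human approval step is a wrapper that refuses anything but a single `SELECT` and asks a person to confirm before running it. The sketch below uses `sqlite3` directly rather than `QuerySQLDataBaseTool`; `run_with_approval` and its `approve` callback are hypothetical names. A keyword check like this is a convenience, not a security boundary, so narrowly scoped database permissions remain the main safeguard."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def run_with_approval(conn, query, approve=input):\n",
"    # Allow only a single SELECT statement (this also rejects WITH ... SELECT queries).\n",
"    statement = query.strip().rstrip(';')\n",
"    if ';' in statement or not statement.upper().startswith('SELECT'):\n",
"        raise ValueError(f'Refusing to run non-SELECT query: {query!r}')\n",
"    # Show the query to a person and require an explicit yes.\n",
"    reply = approve(f'Run this query? [y/N]\\n{query}\\n')\n",
"    if reply.strip().lower() != 'y':\n",
"        return 'Query rejected by reviewer.'\n",
"    return str(conn.execute(statement).fetchall())\n",
"\n",
"# Usage with an in-memory table of 8 employees and an auto-approving reviewer.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY)')\n",
"conn.executemany('INSERT INTO Employee VALUES (?)', [(i,) for i in range(1, 9)])\n",
"print(run_with_approval(conn, 'SELECT COUNT(*) FROM Employee;', approve=lambda prompt: 'y'))"
]
},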
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Answer the question\n",
|
|
"\n",
|
|
"Now that we've got a way to automatically generate and execute queries, we just need to combine the original question and SQL query result to generate a final answer. We can do this by passing the question and the query result to the LLM once more:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'There are a total of 8 employees.'"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from operator import itemgetter\n",
|
|
"\n",
|
|
"from langchain_core.output_parsers import StrOutputParser\n",
|
|
"from langchain_core.prompts import PromptTemplate\n",
|
|
"from langchain_core.runnables import RunnablePassthrough\n",
|
|
"\n",
|
|
"answer_prompt = PromptTemplate.from_template(\n",
|
|
" \"\"\"Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n",
|
|
"\n",
|
|
"Question: {question}\n",
|
|
"SQL Query: {query}\n",
|
|
"SQL Result: {result}\n",
|
|
"Answer: \"\"\"\n",
|
|
")\n",
|
|
"\n",
|
|
"chain = (\n",
|
|
" RunnablePassthrough.assign(query=write_query).assign(\n",
|
|
" result=itemgetter(\"query\") | execute_query\n",
|
|
" )\n",
|
|
" | answer_prompt\n",
|
|
" | llm\n",
|
|
" | StrOutputParser()\n",
|
|
")\n",
|
|
"\n",
|
|
"chain.invoke({\"question\": \"How many employees are there\"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's review what is happening in the LCEL chain above. Suppose this chain is invoked with `{\"question\": question}`.\n",
"- After `RunnablePassthrough.assign(query=write_query)`, the state is a dict with two keys:  \n",
"  `{\"question\": question, \"query\": write_query.invoke({\"question\": question})}`  \n",
"  where `write_query` generates a SQL query in service of answering the question.\n",
"- The chained `.assign(result=...)` adds a third key, `\"result\"`, containing `execute_query.invoke(query)`, where `query` was computed in the previous step.\n",
"- These three values are formatted into the prompt and passed to the LLM.\n",
"- `StrOutputParser()` extracts the string content of the output message.\n",
"\n",
"Note that we are composing LLMs, tools, prompts, and other chains; because each implements the Runnable interface, the output of one can be piped directly into the next."
|
|
]
|
|
},
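{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data flow described above can be mimicked with plain dictionaries. In this sketch, `assign` is a hypothetical helper that behaves like `RunnablePassthrough.assign` for the purposes of this chain (it keeps every input key and adds new keys computed from the input dict), and `fake_write_query` / `fake_execute_query` return fixed values in place of the model and the database."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def assign(**steps):\n",
"    # Like RunnablePassthrough.assign here: keep every input key and add one\n",
"    # new key per step, each computed from the whole input dict.\n",
"    def apply(inputs):\n",
"        out = dict(inputs)\n",
"        for key, fn in steps.items():\n",
"            out[key] = fn(inputs)\n",
"        return out\n",
"    return apply\n",
"\n",
"def fake_write_query(inputs):\n",
"    # Stand-in for write_query: receives the whole dict, returns SQL text.\n",
"    return 'SELECT COUNT(*) FROM Employee'\n",
"\n",
"def fake_execute_query(sql):\n",
"    # Stand-in for execute_query: receives only the SQL string.\n",
"    return '[(8,)]'\n",
"\n",
"first = assign(query=fake_write_query)\n",
"second = assign(result=lambda d: fake_execute_query(d['query']))  # like itemgetter('query') | execute_query\n",
"\n",
"state = second(first({'question': 'How many employees are there'}))\n",
"print(state)\n",
"\n",
"# The answer prompt then receives all three keys.\n",
"template = 'Question: {question}\\nSQL Query: {query}\\nSQL Result: {result}\\nAnswer: '\n",
"print(template.format(**state))"
]
},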
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Next steps\n",
|
|
"\n",
|
|
"For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For these and other advanced techniques, check out:\n",
|
|
"\n",
|
|
"* [Prompting strategies](/docs/use_cases/sql/prompting): Advanced prompt engineering techniques.\n",
|
|
"* [Query checking](/docs/use_cases/sql/query_checking): Add query validation and error handling.\n",
|
|
"* [Large databases](/docs/use_cases/sql/large_db): Techniques for working with large databases."
|
|
]
|
|
},
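{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small taste of query checking, SQLite can validate a statement without running it: prefixing the query with `EXPLAIN` makes SQLite compile it and return the compiled program, and a syntax error or unknown column raises an error at that point. The `check_query` helper below is a hypothetical, SQLite-only sketch of this idea, not the technique used in the linked guide."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"def check_query(conn, query):\n",
"    # EXPLAIN compiles the statement into SQLite bytecode without executing it,\n",
"    # so syntax errors and unknown tables or columns surface here.\n",
"    try:\n",
"        conn.execute('EXPLAIN ' + query)\n",
"    except sqlite3.Error as err:\n",
"        return f'invalid: {err}'\n",
"    return 'ok'\n",
"\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.execute('CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)')\n",
"print(check_query(conn, 'SELECT Title FROM Album'))\n",
"print(check_query(conn, 'SELECT ReleaseDate FROM Album'))\n",
"print(check_query(conn, 'SELEC Title FROM Album'))"
]
},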
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Agents {#agents}\n",
|
|
"\n",
|
|
"LangChain has a SQL Agent which provides a more flexible way of interacting with SQL databases than a chain. The main advantages of using the SQL Agent are:\n",
|
|
"\n",
|
|
"- It can answer questions based on the database's schema as well as on the database's content (like describing a specific table).\n",
|
|
"- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n",
|
|
"- It can query the database as many times as needed to answer the user question.\n",
|
|
"- It will save tokens by only retrieving the schema from relevant tables.\n",
|
|
"\n",
|
|
"To initialize the agent we'll use the [create_sql_agent](https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html) constructor. This agent uses the `SQLDatabaseToolkit` which contains tools to: \n",
|
|
"\n",
|
|
"* Create and execute queries\n",
|
|
"* Check query syntax\n",
|
|
"* Retrieve table descriptions\n",
|
|
"* ... and more\n",
|
|
"\n",
|
|
"### Initializing agent\n",
|
|
"\n",
|
|
"LangChain provides a built-in `create_sql_agent` constructor that accepts an LLM and database as input.\n",
|
|
"\n",
|
|
"Note that we specify `agent_type`, which determines the mechanism the agent uses to call tools:\n",
|
|
"- Tool-calling agents are compatible with LLMs that support function- or tool-calling features.\n",
|
|
"- ReAct agents rely on prompting the LLM to follow a certain natural language pattern (e.g., \"Thought: \", \"Action: \", etc.). They are potentially less reliable, but do not require specialized tool-calling capabilities from an LLM.\n",
|
|
"\n",
|
|
"See our [agent types](/docs/modules/agents/agent_types/) documentation for more detail.\n",
|
|
"\n",
|
|
"Below, we create a tool-calling agent."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.agent_toolkits import create_sql_agent\n",
|
|
"\n",
|
|
"agent_executor = create_sql_agent(llm, db=db, agent_type=\"tool-calling\", verbose=True)"
|
|
]
|
|
},
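{
"cell_type": "markdown",
"metadata": {},
"source": [
"The practical difference between the two agent types shows up in how a tool call is recovered from the model's output. A tool-calling model returns the tool name and arguments as structured data; a ReAct agent has to parse them out of free text. The parser below is a simplified, hypothetical sketch of the ReAct side, assuming the model follows an `Action:` / `Action Input:` layout; real ReAct output parsers handle more cases, such as final answers and malformed output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"def parse_react_step(text):\n",
"    # Pull the tool name and its input out of a ReAct-style completion.\n",
"    match = re.search(r'Action:\\s*(.+?)\\s*\\nAction Input:\\s*(.*)', text, re.DOTALL)\n",
"    if match is None:\n",
"        return None  # no tool call found, e.g. the model gave a final answer\n",
"    return match.group(1).strip(), match.group(2).strip()\n",
"\n",
"completion = 'Thought: I should look at the tables.\\nAction: sql_db_list_tables\\nAction Input: '\n",
"print(parse_react_step(completion))\n",
"\n",
"# A tool-calling model returns the same decision as structured data instead\n",
"# (illustrative shape: a tool name plus an arguments dict), so no parsing is needed.\n",
"structured_call = {'name': 'sql_db_list_tables', 'args': {}}"
]
},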
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Consider how the agent responds to the question below:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new SQL Agent Executor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with `{}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"Customer\" (\n",
|
|
"\t\"CustomerId\" INTEGER NOT NULL, \n",
|
|
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
|
|
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
|
|
"\t\"Company\" NVARCHAR(80), \n",
|
|
"\t\"Address\" NVARCHAR(70), \n",
|
|
"\t\"City\" NVARCHAR(40), \n",
|
|
"\t\"State\" NVARCHAR(40), \n",
|
|
"\t\"Country\" NVARCHAR(40), \n",
|
|
"\t\"PostalCode\" NVARCHAR(10), \n",
|
|
"\t\"Phone\" NVARCHAR(24), \n",
|
|
"\t\"Fax\" NVARCHAR(24), \n",
|
|
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
|
|
"\t\"SupportRepId\" INTEGER, \n",
|
|
"\tPRIMARY KEY (\"CustomerId\"), \n",
|
|
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Customer table:\n",
|
|
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
|
|
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
|
|
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
|
|
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
|
|
"*/\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"Invoice\" (\n",
|
|
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
|
|
"\t\"CustomerId\" INTEGER NOT NULL, \n",
|
|
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
|
|
"\t\"BillingAddress\" NVARCHAR(70), \n",
|
|
"\t\"BillingCity\" NVARCHAR(40), \n",
|
|
"\t\"BillingState\" NVARCHAR(40), \n",
|
|
"\t\"BillingCountry\" NVARCHAR(40), \n",
|
|
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
|
|
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"InvoiceId\"), \n",
|
|
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Invoice table:\n",
|
|
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
|
|
"1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
|
|
"2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
|
|
"3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
|
|
"*/\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"InvoiceLine\" (\n",
|
|
"\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
|
|
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
|
|
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
|
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
|
|
"\t\"Quantity\" INTEGER NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"InvoiceLineId\"), \n",
|
|
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
|
|
"\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from InvoiceLine table:\n",
|
|
"InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n",
|
|
"1\t1\t2\t0.99\t1\n",
|
|
"2\t1\t4\t0.99\t1\n",
|
|
"3\t2\t6\t0.99\t1\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_query` with `{'query': 'SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1'}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003)]\u001b[0m\u001b[32;1m\u001b[1;3mCustomers from the USA spent the most, with a total amount spent of $523.06.\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': \"Which country's customers spent the most?\",\n",
|
|
" 'output': 'Customers from the USA spent the most, with a total amount spent of $523.06.'}"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke({\"input\": \"Which country's customers spent the most?\"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Note that the agent makes several tool calls until it has the information it needs:\n",
"1. Lists the available tables;\n",
"2. Retrieves the schema for three tables;\n",
"3. Queries several of those tables via a join.\n",
|
|
"\n",
|
|
"The agent is then able to use the result of the final query to generate an answer to the original question.\n",
|
|
"\n",
|
|
"The agent can similarly handle qualitative questions:"
|
|
]
|
|
},
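{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what each of those steps hands back to the model, the three tool calls from the trace can be replayed by hand. This sketch uses `sqlite3` with a made-up, two-table stand-in for Chinook; `list_tables`, `get_schema`, and `run_query` are hypothetical helpers that loosely mirror the `sql_db_list_tables`, `sql_db_schema`, and `sql_db_query` tools, not the toolkit's implementation. In the real agent the model chooses each next call based on the previous result; here the sequence is fixed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"\n",
"# Made-up two-table stand-in for Chinook, just large enough to run the steps.\n",
"conn = sqlite3.connect(':memory:')\n",
"conn.executescript('''\n",
"CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);\n",
"CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);\n",
"INSERT INTO Customer VALUES (1, 'USA'), (2, 'Germany'), (3, 'USA');\n",
"INSERT INTO Invoice VALUES (1, 1, 10.0), (2, 2, 15.0), (3, 3, 7.5);\n",
"''')\n",
"\n",
"def list_tables():\n",
"    # Step 1: comma-separated table names, as sql_db_list_tables returns in the trace.\n",
"    rows = conn.execute(\"SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name\").fetchall()\n",
"    return ', '.join(name for (name,) in rows)\n",
"\n",
"def get_schema(table_names):\n",
"    # Step 2: CREATE TABLE statements for a comma-separated list of tables.\n",
"    wanted = {name.strip() for name in table_names.split(',')}\n",
"    rows = conn.execute('SELECT name, sql FROM sqlite_master WHERE type = ?', ('table',)).fetchall()\n",
"    return '\\n\\n'.join(sql for name, sql in rows if name in wanted)\n",
"\n",
"def run_query(sql):\n",
"    # Step 3: run the query and return the rows as a string.\n",
"    return str(conn.execute(sql).fetchall())\n",
"\n",
"print(list_tables())\n",
"print(get_schema('Customer, Invoice'))\n",
"print(run_query(\n",
"    'SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c '\n",
"    'JOIN Invoice i ON c.CustomerId = i.CustomerId '\n",
"    'GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1'\n",
"))"
]
},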
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new SQL Agent Executor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with `{}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `{'table_names': 'PlaylistTrack'}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"PlaylistTrack\" (\n",
|
|
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
|
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
|
|
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
|
|
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from PlaylistTrack table:\n",
|
|
"PlaylistId\tTrackId\n",
|
|
"1\t3402\n",
|
|
"1\t3389\n",
|
|
"1\t3390\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe `PlaylistTrack` table has the following columns:\n",
|
|
"- PlaylistId (INTEGER, NOT NULL)\n",
|
|
"- TrackId (INTEGER, NOT NULL)\n",
|
|
"\n",
|
|
"It has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on TrackId referencing the Track table's TrackId column, and on PlaylistId referencing the Playlist table's PlaylistId column.\n",
|
|
"\n",
|
|
"Here are 3 sample rows from the `PlaylistTrack` table:\n",
|
|
"1. PlaylistId: 1, TrackId: 3402\n",
|
|
"2. PlaylistId: 1, TrackId: 3389\n",
|
|
"3. PlaylistId: 1, TrackId: 3390\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': 'Describe the playlisttrack table',\n",
|
|
" 'output': \"The `PlaylistTrack` table has the following columns:\\n- PlaylistId (INTEGER, NOT NULL)\\n- TrackId (INTEGER, NOT NULL)\\n\\nIt has a composite primary key on the columns PlaylistId and TrackId. Additionally, there are foreign key constraints on TrackId referencing the Track table's TrackId column, and on PlaylistId referencing the Playlist table's PlaylistId column.\\n\\nHere are 3 sample rows from the `PlaylistTrack` table:\\n1. PlaylistId: 1, TrackId: 3402\\n2. PlaylistId: 1, TrackId: 3389\\n3. PlaylistId: 1, TrackId: 3390\"}"
|
|
]
|
|
},
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke(\"Describe the playlisttrack table\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Using a dynamic few-shot prompt\n",
|
|
"\n",
|
|
"To optimize agent performance, we can provide a custom prompt with domain-specific knowledge. In this case we'll create a few-shot prompt with an example selector that dynamically builds the few-shot prompt based on the user input. This helps the model write better queries by inserting relevant example queries into the prompt for reference.\n",
|
|
"\n",
|
|
"First, we need some example pairs of user input and SQL query:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"examples = [\n",
|
|
" {\"input\": \"List all artists.\", \"query\": \"SELECT * FROM Artist;\"},\n",
|
|
" {\n",
|
|
" \"input\": \"Find all albums for the artist 'AC/DC'.\",\n",
|
|
" \"query\": \"SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"List all tracks in the 'Rock' genre.\",\n",
|
|
" \"query\": \"SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"Find the total duration of all tracks.\",\n",
|
|
" \"query\": \"SELECT SUM(Milliseconds) FROM Track;\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"List all customers from Canada.\",\n",
|
|
" \"query\": \"SELECT * FROM Customer WHERE Country = 'Canada';\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"How many tracks are there in the album with ID 5?\",\n",
|
|
" \"query\": \"SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"Find the total number of invoices.\",\n",
|
|
" \"query\": \"SELECT COUNT(*) FROM Invoice;\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"List all tracks that are longer than 5 minutes.\",\n",
|
|
" \"query\": \"SELECT * FROM Track WHERE Milliseconds > 300000;\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"Who are the top 5 customers by total purchase?\",\n",
|
|
" \"query\": \"SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"Which albums are from the year 2000?\",\n",
|
|
" \"query\": \"SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"input\": \"How many employees are there\",\n",
|
|
" \"query\": 'SELECT COUNT(*) FROM \"Employee\"',\n",
|
|
" },\n",
|
|
"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Although stuffing all of these examples into a prompt is an option, we can also select the examples that are most relevant to a given input. For this we can use example selectors. An example selector takes the actual user input and selects some number of examples to add to our few-shot prompt. We'll use a `SemanticSimilarityExampleSelector`, which performs a semantic search using the embeddings and vector store we configure to find the examples most similar to our input:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.vectorstores import FAISS\n",
|
|
"from langchain_core.example_selectors import SemanticSimilarityExampleSelector\n",
|
|
"from langchain_openai import OpenAIEmbeddings\n",
|
|
"\n",
|
|
"example_selector = SemanticSimilarityExampleSelector.from_examples(\n",
|
|
" examples,\n",
|
|
" OpenAIEmbeddings(),\n",
|
|
" FAISS,\n",
|
|
" k=5,\n",
|
|
" input_keys=[\"input\"],\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's try it out:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[{'input': 'How many employees are there',\n",
|
|
" 'query': 'SELECT COUNT(*) FROM \"Employee\"'},\n",
|
|
" {'input': 'Find the total number of invoices.',\n",
|
|
" 'query': 'SELECT COUNT(*) FROM Invoice;'},\n",
|
|
" {'input': 'Who are the top 5 customers by total purchase?',\n",
|
|
" 'query': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;'},\n",
|
|
" {'input': 'List all customers from Canada.',\n",
|
|
" 'query': \"SELECT * FROM Customer WHERE Country = 'Canada';\"},\n",
|
|
" {'input': 'How many tracks are there in the album with ID 5?',\n",
|
|
" 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'}]"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"example_selector.select_examples({\"input\": \"How many employees are there\"})"
|
|
]
|
|
},
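{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea of semantic selection concrete without calling an embeddings API, here is a toy selector that replaces `OpenAIEmbeddings` and `FAISS` with bag-of-words vectors and cosine similarity. Word overlap is a deliberate simplification that will not rank examples the way real embeddings do; `embed_toy` and `select_examples_toy` are hypothetical helpers, not part of LangChain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import re\n",
"from collections import Counter\n",
"\n",
"def embed_toy(text):\n",
"    # Toy 'embedding': lowercase word counts (a stand-in for OpenAIEmbeddings).\n",
"    return Counter(re.findall(r'[a-z0-9]+', text.lower()))\n",
"\n",
"def cosine(a, b):\n",
"    dot = sum(a[word] * b[word] for word in a)\n",
"    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))\n",
"    return dot / norm if norm else 0.0\n",
"\n",
"def select_examples_toy(pairs, user_input, k=2):\n",
"    # Rank stored examples by how similar their 'input' text is to the user input.\n",
"    target = embed_toy(user_input)\n",
"    ranked = sorted(pairs, key=lambda ex: cosine(embed_toy(ex['input']), target), reverse=True)\n",
"    return ranked[:k]\n",
"\n",
"toy_examples = [\n",
"    {'input': 'Find the total number of invoices.', 'query': 'SELECT COUNT(*) FROM Invoice;'},\n",
"    {'input': 'How many employees are there', 'query': 'SELECT COUNT(*) FROM Employee;'},\n",
"    {'input': 'List all tracks that are longer than 5 minutes.', 'query': 'SELECT * FROM Track WHERE Milliseconds > 300000;'},\n",
"]\n",
"print(select_examples_toy(toy_examples, 'How many employees work here?', k=1))"
]
},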
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Note that this particular selector is simply identifying examples whose inputs are semantically similar to a given input.\n",
|
|
"\n",
|
|
"LangChain's `FewShotPromptTemplate` functions as a typical prompt template, but will retrieve relevant few-shot examples using the example selector and format them into the prompt. Let's instantiate one:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_core.prompts import (\n",
|
|
" ChatPromptTemplate,\n",
|
|
" FewShotPromptTemplate,\n",
|
|
" MessagesPlaceholder,\n",
|
|
" PromptTemplate,\n",
|
|
" SystemMessagePromptTemplate,\n",
|
|
")\n",
|
|
"\n",
|
|
"system_prefix = \"\"\"You are an agent designed to interact with a SQL database.\n",
|
|
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
|
|
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
|
|
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
|
|
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
|
|
"You have access to tools for interacting with the database.\n",
|
|
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
|
|
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
|
|
"\n",
|
|
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
|
|
"\n",
|
|
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
|
|
"\n",
|
|
"Here are some examples of user inputs and their corresponding SQL queries:\"\"\"\n",
|
|
"\n",
|
|
"few_shot_prompt = FewShotPromptTemplate(\n",
|
|
" example_selector=example_selector,\n",
|
|
" example_prompt=PromptTemplate.from_template(\n",
|
|
" \"User input: {input}\\nSQL query: {query}\"\n",
|
|
" ),\n",
|
|
" input_variables=[\"input\", \"dialect\", \"top_k\"],\n",
|
|
" prefix=system_prefix,\n",
|
|
" suffix=\"\",\n",
|
|
")"
|
|
]
|
|
},
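{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the few-shot template formats each selected example with `example_prompt`, joins the prefix, the formatted examples, and the suffix with a separator (a blank line by default), and fills in `{dialect}` and `{top_k}`. The plain-Python approximation below illustrates that assembly with a shortened prefix and a hard-coded list standing in for the selector's output; `assemble_few_shot` is a hypothetical helper, and the class's exact formatting rules may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def assemble_few_shot(prefix, selected, example_template, suffix='', separator='\\n\\n', **variables):\n",
"    # Format each selected example, then join prefix, examples and suffix,\n",
"    # skipping empty pieces (such as the empty suffix used above).\n",
"    pieces = [prefix.format(**variables)]\n",
"    pieces += [example_template.format(**ex) for ex in selected]\n",
"    pieces.append(suffix)\n",
"    return separator.join(piece for piece in pieces if piece)\n",
"\n",
"short_prefix = (\n",
"    'Given an input question, create a syntactically correct {dialect} query.\\n'\n",
"    'Limit your query to at most {top_k} results.\\n'\n",
"    'Here are some examples of user inputs and their corresponding SQL queries:'\n",
")\n",
"selected = [  # stands in for what the example selector returned\n",
"    {'input': 'How many employees are there', 'query': 'SELECT COUNT(*) FROM Employee;'},\n",
"    {'input': 'Find the total number of invoices.', 'query': 'SELECT COUNT(*) FROM Invoice;'},\n",
"]\n",
"print(assemble_few_shot(short_prefix, selected, 'User input: {input}\\nSQL query: {query}', dialect='SQLite', top_k=5))"
]
},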
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Since our underlying agent is a [tool-calling agent](/docs/modules/agents/agent_types/tool_calling), our full prompt should be a chat prompt with a human message template and a `MessagesPlaceholder` for the `agent_scratchpad`. The few-shot prompt will serve as our system message:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"full_prompt = ChatPromptTemplate.from_messages(\n",
|
|
" [\n",
|
|
" SystemMessagePromptTemplate(prompt=few_shot_prompt),\n",
|
|
" (\"human\", \"{input}\"),\n",
|
|
" MessagesPlaceholder(\"agent_scratchpad\"),\n",
|
|
" ]\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can invoke the prompt template to inspect how an individual query would be formatted:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"System: You are an agent designed to interact with a SQL database.\n",
|
|
"Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n",
|
|
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n",
|
|
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
|
|
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
|
|
"You have access to tools for interacting with the database.\n",
|
|
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
|
|
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
|
|
"\n",
|
|
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
|
|
"\n",
|
|
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\n",
|
|
"\n",
|
|
"Here are some examples of user inputs and their corresponding SQL queries:\n",
|
|
"\n",
|
|
"User input: List all artists.\n",
|
|
"SQL query: SELECT * FROM Artist;\n",
|
|
"\n",
|
|
"User input: How many employees are there\n",
|
|
"SQL query: SELECT COUNT(*) FROM \"Employee\"\n",
|
|
"\n",
|
|
"User input: How many tracks are there in the album with ID 5?\n",
|
|
"SQL query: SELECT COUNT(*) FROM Track WHERE AlbumId = 5;\n",
|
|
"\n",
|
|
"User input: List all tracks in the 'Rock' genre.\n",
|
|
"SQL query: SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');\n",
|
|
"\n",
|
|
"User input: Which albums are from the year 2000?\n",
|
|
"SQL query: SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';\n",
|
|
"Human: How many artists are there?\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"prompt_val = full_prompt.invoke(\n",
|
|
"    {\n",
|
|
"        \"input\": \"How many artists are there?\",\n",
|
|
" \"top_k\": 5,\n",
|
|
" \"dialect\": \"SQLite\",\n",
|
|
" \"agent_scratchpad\": [],\n",
|
|
" }\n",
|
|
")\n",
|
|
"print(prompt_val.to_string())"
|
|
]
|
|
},
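{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The prompt tells the model not to issue DML statements, but that only instructs the model and does not enforce anything. In line with the security note at the top of this guide, it is safer to also restrict what the connection itself can do. The sketch below uses only Python's built-in `sqlite3` module (not LangChain) and a throwaway `example.db` file created in a temporary directory for illustration. Reopening the file with SQLite's `mode=ro` URI parameter makes any write fail at the driver level, whatever SQL the model generates. For other databases, a read-only user or role serves the same purpose."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import sqlite3\n",
  "import tempfile\n",
  "from pathlib import Path\n",
  "\n",
  "# Throwaway database standing in for Chinook (hypothetical file and data).\n",
  "tmp_dir = tempfile.TemporaryDirectory()\n",
  "db_path = Path(tmp_dir.name) / \"example.db\"\n",
  "setup = sqlite3.connect(db_path)\n",
  "setup.execute(\"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)\")\n",
  "setup.execute(\"INSERT INTO Artist (Name) VALUES ('AC/DC')\")\n",
  "setup.commit()\n",
  "setup.close()\n",
  "\n",
  "# Reopen the same file read-only via an SQLite URI.\n",
  "ro_conn = sqlite3.connect(db_path.as_uri() + \"?mode=ro\", uri=True)\n",
  "row_count = ro_conn.execute(\"SELECT COUNT(*) FROM Artist\").fetchone()\n",
  "print(\"Reads still work:\", row_count)\n",
  "\n",
  "# Any write is rejected by SQLite itself, not by the prompt.\n",
  "write_blocked = False\n",
  "try:\n",
  "    ro_conn.execute(\"DELETE FROM Artist\")\n",
  "except sqlite3.OperationalError as err:\n",
  "    write_blocked = True\n",
  "    print(\"Write rejected:\", err)\n",
  "\n",
  "ro_conn.close()\n",
  "tmp_dir.cleanup()"
 ]
},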
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can then build the agent in the same way as before. On each call, the agent passes the user input through the example selector, and the selected examples are formatted into the system prompt."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"agent = create_sql_agent(\n",
|
|
" llm=llm,\n",
|
|
" db=db,\n",
|
|
" prompt=full_prompt,\n",
|
|
" verbose=True,\n",
|
|
" agent_type=\"tool-calling\",\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new SQL Agent Executor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist;'}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3m[(275,)]\u001b[0m\u001b[32;1m\u001b[1;3mThere are 275 artists in the database.\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': 'How many artists are there?',\n",
|
|
" 'output': 'There are 275 artists in the database.'}"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent.invoke({\"input\": \"How many artists are there?\"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Dealing with high-cardinality columns\n",
|
|
"\n",
|
|
"To filter on columns that contain proper nouns, such as addresses, song names or artists, we first need to double-check the spelling so that the filter matches the values actually stored in the database.\n",
|
|
"\n",
|
|
"We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n",
|
|
"\n",
|
|
"First, we need the unique values for each entity of interest. We define a helper function that runs a query, parses the result into a flat list of strings, strips standalone numbers from each value and removes duplicates:"
|
|
]
|
|
},
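{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To make the spelling lookup concrete before bringing in embeddings, here is a minimal sketch of that step using only Python's standard `difflib` module. Note that it swaps fuzzy string matching in for the vector store described above, and the hard-coded `known_artists` list is a hypothetical stand-in for the distinct values we pull from the database below. String similarity only catches spelling variations, whereas an embedding-based search can also surface values that are phrased differently, which is one reason this guide uses a vector store."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import difflib\n",
  "\n",
  "# Hypothetical stand-in for the distinct proper nouns pulled from the database.\n",
  "known_artists = [\"Alice In Chains\", \"Alanis Morissette\", \"Pearl Jam\", \"Audioslave\", \"AC/DC\"]\n",
  "\n",
  "\n",
  "def lookup_proper_noun(approx, candidates, n=3, cutoff=0.5):\n",
  "    \"\"\"Return up to `n` candidates that closely match `approx`, ignoring case.\"\"\"\n",
  "    by_lower = {candidate.lower(): candidate for candidate in candidates}\n",
  "    matches = difflib.get_close_matches(approx.lower(), list(by_lower), n=n, cutoff=cutoff)\n",
  "    # Map the lowercase matches back to their original spelling.\n",
  "    return [by_lower[match] for match in matches]\n",
  "\n",
  "\n",
  "print(lookup_proper_noun(\"alis in chain\", known_artists))  # ['Alice In Chains']"
 ]
},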
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['Great Recordings of the Century - Mahler: Das Lied von der Erde',\n",
|
|
" 'Holst: The Planets, Op. & Vaughan Williams: Fantasies',\n",
|
|
" 'Chemical Wedding',\n",
|
|
" 'By The Way',\n",
|
|
" 'The Singles']"
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import ast\n",
|
|
"import re\n",
|
|
"\n",
|
|
"\n",
|
|
"def query_as_list(db, query):\n",
|
|
" res = db.run(query)\n",
|
|
" res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
|
|
" res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
|
|
" return list(set(res))\n",
|
|
"\n",
|
|
"\n",
|
|
"artists = query_as_list(db, \"SELECT Name FROM Artist\")\n",
|
|
"albums = query_as_list(db, \"SELECT Title FROM Album\")\n",
|
|
"albums[:5]"
|
|
]
|
|
},
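{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To see what `query_as_list` does to the raw result, the sketch below applies the same three steps to a hand-written string shaped like the output of `db.run` (a Python-literal list of row tuples). The sample string, including the `None` row and the titles containing numbers, is made up for illustration. Note that the regular expression removes *every* standalone number, so a title such as `Symphony No. 5` loses its number. Values that become identical after cleaning are also merged by the `set`."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import ast\n",
  "import re\n",
  "\n",
  "# Hypothetical string in the shape returned by db.run: a list of row tuples.\n",
  "raw = \"[('Symphony No. 5',), ('By The Way',), (None,), ('By The Way 2',)]\"\n",
  "\n",
  "# 1. Parse the string back into Python objects, flatten the rows and drop empty values.\n",
  "values = [el for row in ast.literal_eval(raw) for el in row if el]\n",
  "# 2. Remove standalone numbers (such as \"5\") and trim leftover whitespace.\n",
  "cleaned = [re.sub(r\"\\b\\d+\\b\", \"\", value).strip() for value in values]\n",
  "# 3. De-duplicate; sorted here only so the printed output is deterministic.\n",
  "unique_values = sorted(set(cleaned))\n",
  "print(unique_values)  # ['By The Way', 'Symphony No.']"
 ]
},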
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Using these values, we can build a vector store and wrap its retriever in a **retriever tool** that the agent can call at its discretion."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
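"from langchain.agents.agent_toolkits import create_retriever_tool\n",
"from langchain_community.vectorstores import FAISS\n",
"from langchain_openai import OpenAIEmbeddings\n",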
|
|
"\n",
|
|
"vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())\n",
|
|
"retriever = vector_db.as_retriever(search_kwargs={\"k\": 5})\n",
|
|
"description = \"\"\"Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \\\n",
|
|
"valid proper nouns. Use the noun most similar to the search.\"\"\"\n",
|
|
"retriever_tool = create_retriever_tool(\n",
|
|
" retriever,\n",
|
|
" name=\"search_proper_nouns\",\n",
|
|
" description=description,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's try it out:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Alice In Chains\n",
|
|
"\n",
|
|
"Alanis Morissette\n",
|
|
"\n",
|
|
"Pearl Jam\n",
|
|
"\n",
|
|
"Pearl Jam\n",
|
|
"\n",
|
|
"Audioslave\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(retriever_tool.invoke(\"Alice Chains\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This way, if the agent determines that it needs to filter on an artist name such as \"Alice Chains\", it can first use the retriever tool to look up the correctly spelled values stored in the database.\n",
|
|
"\n",
|
|
"Putting this together:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"system = \"\"\"You are an agent designed to interact with a SQL database.\n",
|
|
"Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n",
|
|
"Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.\n",
|
|
"You can order the results by a relevant column to return the most interesting examples in the database.\n",
|
|
"Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n",
|
|
"You have access to tools for interacting with the database.\n",
|
|
"Only use the given tools. Only use the information returned by the tools to construct your final answer.\n",
|
|
"You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.\n",
|
|
"\n",
|
|
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.\n",
|
|
"\n",
|
|
"If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the \"search_proper_nouns\" tool!\n",
|
|
"\n",
|
|
"You have access to the following tables: {table_names}\n",
|
|
"\n",
|
|
"If the question does not seem related to the database, just return \"I don't know\" as the answer.\"\"\"\n",
|
|
"\n",
|
|
"prompt = ChatPromptTemplate.from_messages(\n",
|
|
" [(\"system\", system), (\"human\", \"{input}\"), MessagesPlaceholder(\"agent_scratchpad\")]\n",
|
|
")\n",
|
|
"agent = create_sql_agent(\n",
|
|
" llm=llm,\n",
|
|
" db=db,\n",
|
|
" extra_tools=[retriever_tool],\n",
|
|
" prompt=prompt,\n",
|
|
" agent_type=\"tool-calling\",\n",
|
|
" verbose=True,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new SQL Agent Executor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `search_proper_nouns` with `{'query': 'alis in chain'}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3mAlice In Chains\n",
|
|
"\n",
|
|
"Aisha Duo\n",
|
|
"\n",
|
|
"Xis\n",
|
|
"\n",
|
|
"Da Lama Ao Caos\n",
|
|
"\n",
|
|
"A-Sides\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_query` with `{'query': \"SELECT COUNT(*) AS album_count FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\"}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3m[(1,)]\u001b[0m\u001b[32;1m\u001b[1;3mAlice In Chains has 1 album.\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': 'How many albums does alis in chain have?',\n",
|
|
" 'output': 'Alice In Chains has 1 album.'}"
|
|
]
|
|
},
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent.invoke({\"input\": \"How many albums does alis in chain have?\"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"As we can see, the agent first used the `search_proper_nouns` tool to find the correct spelling of the artist's name (\"Alice In Chains\") and then used that value in its SQL filter."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Next steps\n",
|
|
"\n",
|
|
"Under the hood, `create_sql_agent` simply passes SQL tools to more generic agent constructors. To learn more about the built-in generic agent types, as well as how to build custom agents, head to the [Agents module](/docs/modules/agents/).\n",
|
|
"\n",
|
|
"The built-in `AgentExecutor` runs a simple loop (agent action -> tool call -> agent action -> ...) until the agent returns a final answer or an iteration limit is reached. To build more complex agent runtimes, head to the [LangGraph section](/docs/langgraph)."
|
|
]
|
|
}
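,
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To make the agent action -> tool call -> agent action loop concrete, here is a deliberately simplified, self-contained sketch. It is not how `AgentExecutor` is implemented: the scripted `fake_model` function is a hypothetical stand-in for the LLM's tool-calling decisions, and `sql_db_query` here is a toy tool that queries a tiny in-memory SQLite table built just for this example. The loop feeds each tool result back to the model, much like the `agent_scratchpad`, until the model returns a final answer or a step limit is reached."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import sqlite3\n",
  "\n",
  "# Tiny in-memory database standing in for Chinook (hypothetical data).\n",
  "conn = sqlite3.connect(\":memory:\")\n",
  "conn.execute(\"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)\")\n",
  "conn.executemany(\"INSERT INTO Artist (Name) VALUES (?)\", [(\"AC/DC\",), (\"Accept\",), (\"Aerosmith\",)])\n",
  "\n",
  "\n",
  "def sql_db_query(query):\n",
  "    \"\"\"Toy tool: run a query and return the rows as a string.\"\"\"\n",
  "    return str(conn.execute(query).fetchall())\n",
  "\n",
  "\n",
  "TOOLS = {\"sql_db_query\": sql_db_query}\n",
  "\n",
  "\n",
  "def fake_model(question, scratchpad):\n",
  "    \"\"\"Scripted stand-in for the LLM: call a tool once, then answer from its output.\"\"\"\n",
  "    if not scratchpad:\n",
  "        return {\"tool\": \"sql_db_query\", \"tool_input\": \"SELECT COUNT(*) FROM Artist;\"}\n",
  "    last_observation = scratchpad[-1][1]  # e.g. \"[(3,)]\"\n",
  "    count = last_observation.strip(\"[](),\")\n",
  "    return {\"final_answer\": f\"There are {count} artists in the database.\"}\n",
  "\n",
  "\n",
  "def run_agent(question, max_steps=5):\n",
  "    scratchpad = []  # (action, observation) pairs, playing the role of agent_scratchpad\n",
  "    for _ in range(max_steps):\n",
  "        action = fake_model(question, scratchpad)  # agent action\n",
  "        if \"final_answer\" in action:\n",
  "            return action[\"final_answer\"]\n",
  "        observation = TOOLS[action[\"tool\"]](action[\"tool_input\"])  # tool call\n",
  "        scratchpad.append((action, observation))  # fed back to the model on the next step\n",
  "    return \"Stopped: step limit reached.\"\n",
  "\n",
  "\n",
  "answer = run_agent(\"How many artists are there?\")\n",
  "print(answer)  # There are 3 artists in the database."
 ]
}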
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|