mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-13 18:36:19 +00:00
705 lines
26 KiB
Plaintext
705 lines
26 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0e499e90-7a6d-4fab-8aab-31a4df417601",
|
|
"metadata": {},
|
|
"source": [
|
|
"# SQL Database\n",
|
|
"\n",
|
|
"This notebook showcases an agent designed to interact with `SQL` databases.\n",
|
|
"\n",
|
|
"It is designed to answer more general questions about a database, as well as recover from errors.\n",
|
|
"\n",
|
|
"Note that, as this agent is in active development, not all answers may be correct.\n",
|
|
"\n",
|
|
"Additionally, it is not guaranteed that the agent won't perform DML statements on your database given certain questions. \n",
|
|
"\n",
|
|
"Be careful running it on sensitive data!\n",
|
|
"\n",
|
|
"This uses the example `Chinook` database. \n",
|
|
"\n",
|
|
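"To set it up, follow [these instructions](https://database.guide/2-sample-databases-sqlite/) and place the .db file in a notebooks folder at the root of this repository. The code below opens it with the relative URI `sqlite:///Chinook.db`, so `Chinook.db` must be in the directory the notebook is run from."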
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "39ee5811-7f7d-47a9-941a-83dcf4b74239",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\n",
|
|
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "804533b1-2f16-497b-821b-c82d67fcf7b6",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Initialize our database"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "42bd5a41-672a-4a53-b70a-2f0c0555758c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.utilities.sql_database import SQLDatabase\n",
|
|
"\n",
|
|
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fabfc8b0-2354-4f4d-b334-cd8ae3a12b0a",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Quickstart\n",
|
|
"\n",
|
|
"Following the [SQL use case docs](/docs/use_cases/sql/agents), we can use the `create_sql_agent` helper."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "c0504b5c-72a8-4407-9056-944615f1d480",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.agent_toolkits import create_sql_agent\n",
|
|
"from langchain_openai import ChatOpenAI\n",
|
|
"\n",
|
|
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
|
|
"agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "75f2f603-6004-4363-a781-a6e352124ffa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with `{}`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `Invoice,Customer`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"Customer\" (\n",
|
|
"\t\"CustomerId\" INTEGER NOT NULL, \n",
|
|
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
|
|
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
|
|
"\t\"Company\" NVARCHAR(80), \n",
|
|
"\t\"Address\" NVARCHAR(70), \n",
|
|
"\t\"City\" NVARCHAR(40), \n",
|
|
"\t\"State\" NVARCHAR(40), \n",
|
|
"\t\"Country\" NVARCHAR(40), \n",
|
|
"\t\"PostalCode\" NVARCHAR(10), \n",
|
|
"\t\"Phone\" NVARCHAR(24), \n",
|
|
"\t\"Fax\" NVARCHAR(24), \n",
|
|
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
|
|
"\t\"SupportRepId\" INTEGER, \n",
|
|
"\tPRIMARY KEY (\"CustomerId\"), \n",
|
|
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Customer table:\n",
|
|
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
|
|
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
|
|
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
|
|
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
|
|
"*/\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"Invoice\" (\n",
|
|
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
|
|
"\t\"CustomerId\" INTEGER NOT NULL, \n",
|
|
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
|
|
"\t\"BillingAddress\" NVARCHAR(70), \n",
|
|
"\t\"BillingCity\" NVARCHAR(40), \n",
|
|
"\t\"BillingState\" NVARCHAR(40), \n",
|
|
"\t\"BillingCountry\" NVARCHAR(40), \n",
|
|
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
|
|
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"InvoiceId\"), \n",
|
|
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Invoice table:\n",
|
|
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
|
|
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
|
|
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
|
|
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`\n",
|
|
"responded: To list the total sales per country, I can query the \"Invoice\" and \"Customer\" tables. I will join these tables on the \"CustomerId\" column and group the results by the \"BillingCountry\" column. Then, I will calculate the sum of the \"Total\" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.\n",
|
|
"\n",
|
|
"Here is the SQL query:\n",
|
|
"\n",
|
|
"```sql\n",
|
|
"SELECT c.Country, SUM(i.Total) AS TotalSales\n",
|
|
"FROM Invoice i\n",
|
|
"JOIN Customer c ON i.CustomerId = c.CustomerId\n",
|
|
"GROUP BY c.Country\n",
|
|
"ORDER BY TotalSales DESC\n",
|
|
"LIMIT 10;\n",
|
|
"```\n",
|
|
"\n",
|
|
"Now, I will execute this query to get the total sales per country.\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]\u001b[0m\u001b[32;1m\u001b[1;3mThe total sales per country are as follows:\n",
|
|
"\n",
|
|
"1. USA: $523.06\n",
|
|
"2. Canada: $303.96\n",
|
|
"3. France: $195.10\n",
|
|
"4. Brazil: $190.10\n",
|
|
"5. Germany: $156.48\n",
|
|
"6. United Kingdom: $112.86\n",
|
|
"7. Czech Republic: $90.24\n",
|
|
"8. Portugal: $77.24\n",
|
|
"9. India: $75.26\n",
|
|
"10. Chile: $46.62\n",
|
|
"\n",
|
|
"To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
|
|
" 'output': 'The total sales per country are as follows:\\n\\n1. USA: $523.06\\n2. Canada: $303.96\\n3. France: $195.10\\n4. Brazil: $190.10\\n5. Germany: $156.48\\n6. United Kingdom: $112.86\\n7. Czech Republic: $90.24\\n8. Portugal: $77.24\\n9. India: $75.26\\n10. Chile: $46.62\\n\\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke(\n",
|
|
" {\"input\": \"List the total sales per country. Which country's customers spent the most?\"}\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ec927ac6-9b2a-4e8a-9a6e-3e429191875c",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"## Toolkit\n",
|
|
"\n",
|
|
"We can look at what runs under the hood with the `create_sql_agent` helper above.\n",
|
|
"\n",
|
|
"We can also highlight how the toolkit is used with the agent specifically."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "056ddd8a-9835-41ef-84a3-6b680c22248d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
|
|
"from langchain_openai import ChatOpenAI\n",
|
|
"\n",
|
|
"toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0))\n",
|
|
"context = toolkit.get_context()\n",
|
|
"tools = toolkit.get_tools()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "08a653b1-02f2-45f3-b357-acff5a6b3869",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Use SQLDatabaseToolkit within an Agent\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "3f673e63-49bf-4112-b74f-d7baa91ed81f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain_community.agent_toolkits.sql.prompt import SQL_FUNCTIONS_SUFFIX\n",
|
|
"from langchain_core.messages import AIMessage\n",
|
|
"from langchain_core.prompts.chat import (\n",
|
|
" ChatPromptTemplate,\n",
|
|
" HumanMessagePromptTemplate,\n",
|
|
" MessagesPlaceholder,\n",
|
|
")\n",
|
|
"\n",
|
|
"messages = [\n",
|
|
" HumanMessagePromptTemplate.from_template(\"{input}\"),\n",
|
|
" AIMessage(content=SQL_FUNCTIONS_SUFFIX),\n",
|
|
" MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
|
|
"]\n",
|
|
"\n",
|
|
"prompt = ChatPromptTemplate.from_messages(messages)\n",
|
|
"prompt = prompt.partial(**context)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "83f57336-1757-4a54-bf79-591c0a97e6a4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.agents import create_openai_tools_agent\n",
|
|
"from langchain.agents.agent import AgentExecutor\n",
|
|
"\n",
|
|
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
|
|
"\n",
|
|
"agent = create_openai_tools_agent(llm, tools, prompt)\n",
|
|
"\n",
|
|
"agent_executor = AgentExecutor(\n",
|
|
" agent=agent,\n",
|
|
" tools=tools,\n",
|
|
" verbose=True,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f77805fb-c05e-408c-9b69-6d2a774c208e",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Disclaimer ⚠️\n",
|
|
"\n",
|
|
"The query chain may generate insert/update/delete queries. When this is not expected, use a custom prompt or create a SQL user without write permissions.\n",
|
|
"\n",
|
|
"The end user might overload your SQL database by asking a simple question such as \"run the biggest query possible\". The generated query might look like:\n",
|
|
"\n",
|
|
"```sql\n",
|
|
"SELECT * FROM \"public\".\"users\"\n",
|
|
" JOIN \"public\".\"user_permissions\" ON \"public\".\"users\".id = \"public\".\"user_permissions\".user_id\n",
|
|
" JOIN \"public\".\"projects\" ON \"public\".\"users\".id = \"public\".\"projects\".user_id\n",
|
|
" JOIN \"public\".\"events\" ON \"public\".\"projects\".id = \"public\".\"events\".project_id;\n",
|
|
"```\n",
|
|
"\n",
|
|
"For a transactional SQL database, if one of the tables above contains millions of rows, the query might cause trouble for other applications using the same database.\n",
|
|
"\n",
|
|
"Most data-warehouse-oriented databases support user-level quotas for limiting resource usage.\n",
|
|
"\n",
|
|
"## Example: describing a table"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "f740fbef-d635-4121-856d-33593ae6bfc0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with ``\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `Playlist`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"Playlist\" (\n",
|
|
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
|
"\t\"Name\" NVARCHAR(120), \n",
|
|
"\tPRIMARY KEY (\"PlaylistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Playlist table:\n",
|
|
"PlaylistId\tName\n",
|
|
"1\tMusic\n",
|
|
"2\tMovies\n",
|
|
"3\tTV Shows\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe schema of the \"Playlist\" table is as follows:\n",
|
|
"\n",
|
|
"- PlaylistId: INTEGER (Primary Key)\n",
|
|
"- Name: NVARCHAR(120)\n",
|
|
"\n",
|
|
"Here are three sample rows from the \"Playlist\" table:\n",
|
|
"\n",
|
|
"| PlaylistId | Name |\n",
|
|
"|------------|-----------|\n",
|
|
"| 1 | Music |\n",
|
|
"| 2 | Movies |\n",
|
|
"| 3 | TV Shows |\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': 'Describe the schema of the playlist table',\n",
|
|
" 'output': 'The schema of the \"Playlist\" table is as follows:\\n\\n- PlaylistId: INTEGER (Primary Key)\\n- Name: NVARCHAR(120)\\n\\nHere are three sample rows from the \"Playlist\" table:\\n\\n| PlaylistId | Name |\\n|------------|-----------|\\n| 1 | Music |\\n| 2 | Movies |\\n| 3 | TV Shows |'}"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke({\"input\": \"Describe the schema of the playlist table\"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9abcfe8e-1868-42a4-8345-ad2d9b44c681",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Example: describing a table, recovering from an error\n",
|
|
"\n",
|
|
"In this example, the agent is asked about a table, `playlists`, that doesn't exist; it lists the available tables and finds the next best match, `Playlist`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "d78ddb6e-f499-4e90-9604-fe22a9c3592e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with ``\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `Playlist`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"Playlist\" (\n",
|
|
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
|
"\t\"Name\" NVARCHAR(120), \n",
|
|
"\tPRIMARY KEY (\"PlaylistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Playlist table:\n",
|
|
"PlaylistId\tName\n",
|
|
"1\tMusic\n",
|
|
"2\tMovies\n",
|
|
"3\tTV Shows\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3mThe \"Playlists\" table has the following schema:\n",
|
|
"\n",
|
|
"- PlaylistId: INTEGER (Primary Key)\n",
|
|
"- Name: NVARCHAR(120)\n",
|
|
"\n",
|
|
"Here are three sample rows from the \"Playlists\" table:\n",
|
|
"\n",
|
|
"| PlaylistId | Name |\n",
|
|
"|------------|-----------|\n",
|
|
"| 1 | Music |\n",
|
|
"| 2 | Movies |\n",
|
|
"| 3 | TV Shows |\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': 'Describe the playlists table',\n",
|
|
" 'output': 'The \"Playlists\" table has the following schema:\\n\\n- PlaylistId: INTEGER (Primary Key)\\n- Name: NVARCHAR(120)\\n\\nHere are three sample rows from the \"Playlists\" table:\\n\\n| PlaylistId | Name |\\n|------------|-----------|\\n| 1 | Music |\\n| 2 | Movies |\\n| 3 | TV Shows |'}"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke({\"input\": \"Describe the playlists table\"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6fbc26af-97e4-4a21-82aa-48bdc992da26",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Example: running queries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "5457af9a-037d-4dba-9316-0abfe841f114",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with ``\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `Invoice, Customer`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"Customer\" (\n",
|
|
"\t\"CustomerId\" INTEGER NOT NULL, \n",
|
|
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
|
|
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
|
|
"\t\"Company\" NVARCHAR(80), \n",
|
|
"\t\"Address\" NVARCHAR(70), \n",
|
|
"\t\"City\" NVARCHAR(40), \n",
|
|
"\t\"State\" NVARCHAR(40), \n",
|
|
"\t\"Country\" NVARCHAR(40), \n",
|
|
"\t\"PostalCode\" NVARCHAR(10), \n",
|
|
"\t\"Phone\" NVARCHAR(24), \n",
|
|
"\t\"Fax\" NVARCHAR(24), \n",
|
|
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
|
|
"\t\"SupportRepId\" INTEGER, \n",
|
|
"\tPRIMARY KEY (\"CustomerId\"), \n",
|
|
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Customer table:\n",
|
|
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
|
|
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
|
|
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
|
|
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
|
|
"*/\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"Invoice\" (\n",
|
|
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
|
|
"\t\"CustomerId\" INTEGER NOT NULL, \n",
|
|
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
|
|
"\t\"BillingAddress\" NVARCHAR(70), \n",
|
|
"\t\"BillingCity\" NVARCHAR(40), \n",
|
|
"\t\"BillingState\" NVARCHAR(40), \n",
|
|
"\t\"BillingCountry\" NVARCHAR(40), \n",
|
|
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
|
|
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"InvoiceId\"), \n",
|
|
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Invoice table:\n",
|
|
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
|
|
"1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
|
|
"2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
|
|
"3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 1;`\n",
|
|
"responded: Based on the schema of the `Customer` and `Invoice` tables, we can join these two tables on the `CustomerId` column to get the total sales per customer. Then, we can group the results by the `Country` column to get the total sales per country. Finally, we can sort the results in descending order of total sales and select the country with the highest total sales.\n",
|
|
"\n",
|
|
"Here is the SQL query to achieve this:\n",
|
|
"\n",
|
|
"```sql\n",
|
|
"SELECT c.Country, SUM(i.Total) AS TotalSales\n",
|
|
"FROM Customer c\n",
|
|
"JOIN Invoice i ON c.CustomerId = i.CustomerId\n",
|
|
"GROUP BY c.Country\n",
|
|
"ORDER BY TotalSales DESC\n",
|
|
"LIMIT 1;\n",
|
|
"```\n",
|
|
"\n",
|
|
"Let me execute this query to find out which country's customers spent the most.\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3m[('USA', 523.0600000000003)]\u001b[0m\u001b[32;1m\u001b[1;3mThe country whose customers spent the most is the USA, with a total sales amount of $523.06.\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': \"List the total sales per country. Which country's customers spent the most?\",\n",
|
|
" 'output': 'The country whose customers spent the most is the USA, with a total sales amount of $523.06.'}"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke(\n",
|
|
" {\n",
|
|
" \"input\": \"List the total sales per country. Which country's customers spent the most?\"\n",
|
|
" }\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7c7503b5-d9d9-4faa-b064-29fcdb5ff213",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Recovering from an error\n",
|
|
"\n",
|
|
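"In this example, the agent is able to recover from an error after initially trying to access an attribute (`Track.ArtistId`) which doesn't exist. Note that LLM runs are not deterministic: in the run recorded below, the agent inspected the schema first and built a correct query through the `Album` table without hitting that error."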
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "4d8d8933-c1e5-4ba9-8143-1059606c4adb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_list_tables` with ``\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[38;5;200m\u001b[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `InvoiceLine`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"InvoiceLine\" (\n",
|
|
"\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
|
|
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
|
|
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
|
"\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
|
|
"\t\"Quantity\" INTEGER NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"InvoiceLineId\"), \n",
|
|
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
|
|
"\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from InvoiceLine table:\n",
|
|
"InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n",
|
|
"1\t1\t2\t0.99\t1\n",
|
|
"2\t1\t4\t0.99\t1\n",
|
|
"3\t2\t6\t0.99\t1\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_schema` with `Artist`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[33;1m\u001b[1;3m\n",
|
|
"CREATE TABLE \"Artist\" (\n",
|
|
"\t\"ArtistId\" INTEGER NOT NULL, \n",
|
|
"\t\"Name\" NVARCHAR(120), \n",
|
|
"\tPRIMARY KEY (\"ArtistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Artist table:\n",
|
|
"ArtistId\tName\n",
|
|
"1\tAC/DC\n",
|
|
"2\tAccept\n",
|
|
"3\tAerosmith\n",
|
|
"*/\u001b[0m\u001b[32;1m\u001b[1;3m\n",
|
|
"Invoking: `sql_db_query` with `SELECT Artist.Name, SUM(InvoiceLine.Quantity) AS TotalQuantity FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY TotalQuantity DESC LIMIT 3`\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[0m\u001b[36;1m\u001b[1;3m[('Iron Maiden', 140), ('U2', 107), ('Metallica', 91)]\u001b[0m\u001b[32;1m\u001b[1;3mThe top 3 best selling artists are:\n",
|
|
"\n",
|
|
"1. Iron Maiden - 140 total quantity sold\n",
|
|
"2. U2 - 107 total quantity sold\n",
|
|
"3. Metallica - 91 total quantity sold\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'input': 'Who are the top 3 best selling artists?',\n",
|
|
" 'output': 'The top 3 best selling artists are:\\n\\n1. Iron Maiden - 140 total quantity sold\\n2. U2 - 107 total quantity sold\\n3. Metallica - 91 total quantity sold'}"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"agent_executor.invoke({\"input\": \"Who are the top 3 best selling artists?\"})"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|