mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-11 18:16:12 +00:00
621 lines
25 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0e499e90-7a6d-4fab-8aab-31a4df417601",
   "metadata": {},
   "source": [
    "# SQLDatabase Toolkit\n",
    "\n",
    "This will help you get started with the SQL Database [toolkit](/docs/concepts/tools/#toolkits). For detailed documentation of all `SQLDatabaseToolkit` features and configurations, head to the [API reference](https://python.langchain.com/api_reference/community/agent_toolkits/langchain_community.agent_toolkits.sql.toolkit.SQLDatabaseToolkit.html).\n",
    "\n",
    "The tools in `SQLDatabaseToolkit` are designed to interact with a SQL database.\n",
    "\n",
    "A common application is to enable agents to answer questions using data in a relational database, potentially in an iterative fashion (e.g., recovering from errors).\n",
    "\n",
    "**⚠️ Security note ⚠️**\n",
    "\n",
    "Building Q&A systems over SQL databases requires executing model-generated SQL queries, which carries inherent risks. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain's or agent's needs. This mitigates, but does not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
    "\n",
    "## Setup\n",
    "\n",
    "If you want to get automated tracing from runs of individual tools, you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting the lines below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3de6e3be-1fd9-42a3-8564-8ca7dca11e1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
    "# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31896b61-68d2-4b4d-be9d-b829eda327d1",
   "metadata": {},
   "source": [
    "### Installation\n",
    "\n",
    "This toolkit lives in the `langchain-community` package:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4933e04-9120-4ccc-9ef7-369987823b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet langchain-community"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ad08dbe-1642-448c-b58d-153810024375",
   "metadata": {},
   "source": [
    "For demonstration purposes, we will access a prompt in the LangChain [Hub](https://smith.langchain.com/hub). We will also require `langgraph` to demonstrate the use of the toolkit with an agent. Neither package is required to use the toolkit itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3dead45-9908-497d-a5a3-bce30642e88f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet langchainhub langgraph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "804533b1-2f16-497b-821b-c82d67fcf7b6",
   "metadata": {},
   "source": [
    "## Instantiation\n",
    "\n",
    "`SQLDatabaseToolkit` requires:\n",
    "\n",
    "- a [SQLDatabase](https://python.langchain.com/api_reference/community/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) object;\n",
    "- an LLM or chat model (used to instantiate the [QuerySQLCheckerTool](https://python.langchain.com/api_reference/community/tools/langchain_community.tools.sql_database.tool.QuerySQLCheckerTool.html) tool).\n",
    "\n",
    "Below, we instantiate the toolkit with these objects. Let's first create a database object.\n",
    "\n",
    "This guide uses the example `Chinook` database based on [these instructions](https://database.guide/2-sample-databases-sqlite/).\n",
    "\n",
    "Below we use the `requests` library to pull the `.sql` file and create an in-memory SQLite database. Note that this approach is lightweight but ephemeral and not thread-safe. If you'd prefer, you can follow the instructions to save the file locally as `Chinook.db` and instantiate the database via `db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "40d05f9b-5a8f-4307-8f8b-4153db0fdfa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "\n",
    "import requests\n",
    "from langchain_community.utilities.sql_database import SQLDatabase\n",
    "from sqlalchemy import create_engine\n",
    "from sqlalchemy.pool import StaticPool\n",
    "\n",
    "\n",
    "def get_engine_for_chinook_db():\n",
    "    \"\"\"Pull sql file, populate in-memory database, and create engine.\"\"\"\n",
    "    url = \"https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql\"\n",
    "    response = requests.get(url, timeout=30)\n",
    "    response.raise_for_status()  # fail loudly if the download did not succeed\n",
    "    sql_script = response.text\n",
    "\n",
    "    # A single shared connection keeps the in-memory data alive for the engine\n",
    "    connection = sqlite3.connect(\":memory:\", check_same_thread=False)\n",
    "    connection.executescript(sql_script)\n",
    "    return create_engine(\n",
    "        \"sqlite://\",\n",
    "        creator=lambda: connection,\n",
    "        poolclass=StaticPool,\n",
    "        connect_args={\"check_same_thread\": False},\n",
    "    )\n",
    "\n",
    "\n",
    "engine = get_engine_for_chinook_db()\n",
    "\n",
    "db = SQLDatabase(engine)"
   ]
  },
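  {
   "cell_type": "markdown",
   "id": "sqlite-read-only-note",
   "metadata": {},
   "source": [
    "The security note above recommends scoping database permissions as narrowly as possible. If you save `Chinook.db` locally rather than using the in-memory copy, one option for SQLite is to open the file read-only. The sketch below uses only Python's standard library and a throwaway database file. The `artist` table and its two rows are hypothetical stand-ins for Chinook data. It shows that reads succeed while writes are rejected. We assume, without demonstrating it here, that SQLAlchemy accepts the equivalent URL `sqlite:///file:Chinook.db?mode=ro&uri=true`, which you could then pass to `SQLDatabase.from_uri`. A read-only file connection does not replace proper users and grants on a server database, but it is a cheap safeguard for local SQLite files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sqlite-read-only-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def open_read_only(db_path: Path) -> sqlite3.Connection:\n",
    "    # mode=ro in a file: URI makes SQLite reject any statement that writes\n",
    "    return sqlite3.connect(f\"{db_path.as_uri()}?mode=ro\", uri=True)\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    db_path = Path(tmp) / \"example.db\"\n",
    "\n",
    "    # Create a small, hypothetical table using a normal read-write connection\n",
    "    rw = sqlite3.connect(db_path)\n",
    "    rw.execute(\"CREATE TABLE artist (id INTEGER PRIMARY KEY, name TEXT)\")\n",
    "    rw.executemany(\"INSERT INTO artist (name) VALUES (?)\", [(\"AC/DC\",), (\"Accept\",)])\n",
    "    rw.commit()\n",
    "    rw.close()\n",
    "\n",
    "    ro = open_read_only(db_path)\n",
    "    try:\n",
    "        rows = ro.execute(\"SELECT name FROM artist ORDER BY id\").fetchall()\n",
    "        print(rows)  # [('AC/DC',), ('Accept',)]\n",
    "        try:\n",
    "            ro.execute(\"DELETE FROM artist\")\n",
    "            write_blocked = False\n",
    "        except sqlite3.OperationalError as exc:\n",
    "            # Typically \"attempt to write a readonly database\"\n",
    "            write_blocked = True\n",
    "            print(\"write rejected:\", exc)\n",
    "    finally:\n",
    "        ro.close()"
   ]
  },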
  {
   "cell_type": "markdown",
   "id": "2b9a6326-78fd-4c42-a1cb-4316619ac449",
   "metadata": {},
   "source": [
    "We will also need an LLM or chat model:\n",
    "\n",
    "import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
    "\n",
    "<ChatModelTabs customVarName=\"llm\" />\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cc6e6108-83d9-404f-8f31-474c2fbf5f6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | output: false\n",
    "# | echo: false\n",
    "\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "llm = ChatOpenAI(temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77925e72-4730-43c3-8726-d68cedf635f4",
   "metadata": {},
   "source": [
    "We can now instantiate the toolkit:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42bd5a41-672a-4a53-b70a-2f0c0555758c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit\n",
    "\n",
    "toolkit = SQLDatabaseToolkit(db=db, llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2f882cf-4156-4a9f-a714-db97ec8ccc37",
   "metadata": {},
   "source": [
    "## Tools\n",
    "\n",
    "View available tools:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a18c3e69-bee0-4f5d-813e-eeb540f41b98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[QuerySQLDatabaseTool(description=\"Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),\n",
       " InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),\n",
       " ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>),\n",
       " QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x103d5fa60>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x10742d720>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x10742f7f0>, root_client=<openai.OpenAI object at 0x103d5fac0>, root_async_client=<openai.AsyncOpenAI object at 0x10742d780>, temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\\n{query}\\nDouble check the {dialect} query above for common mistakes, including:\\n- Using NOT IN with NULL values\\n- Using UNION when UNION ALL should have been used\\n- Using BETWEEN for exclusive ranges\\n- Data type mismatch in predicates\\n- Properly quoting identifiers\\n- Using the correct number of arguments for functions\\n- Casting to the correct data type\\n- Using the proper columns for joins\\n\\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\\n\\nOutput the final SQL query only.\\n\\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x10742d720>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x10742f7f0>, root_client=<openai.OpenAI object at 0x103d5fac0>, root_async_client=<openai.AsyncOpenAI object at 0x10742d780>, temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "toolkit.get_tools()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e3fd064-aa86-448d-8db3-3c55eaa5bc15",
   "metadata": {},
   "source": [
    "You can use the individual tools directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fa8d00c-750c-4803-9b66-057d12b26b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.sql_database.tool import (\n",
    "    InfoSQLDatabaseTool,\n",
    "    ListSQLDatabaseTool,\n",
    "    QuerySQLCheckerTool,\n",
    "    QuerySQLDatabaseTool,\n",
    ")\n",
    "\n",
    "# Each tool wraps the `db` object created above\n",
    "print(ListSQLDatabaseTool(db=db).invoke(\"\"))\n",
    "print(InfoSQLDatabaseTool(db=db).invoke(\"Artist\"))\n",
    "print(QuerySQLDatabaseTool(db=db).invoke(\"SELECT Name FROM Artist LIMIT 3\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c067e0ed-dcca-4dcc-81b2-a0eeb4fc2a9f",
   "metadata": {},
   "source": [
    "## Use within an agent\n",
    "\n",
    "Following the [SQL Q&A Tutorial](/docs/tutorials/sql_qa/#agents), below we equip a simple question-answering agent with the tools in our toolkit. First we pull a relevant prompt and populate it with its required parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eda12f8b-be90-4697-ac84-2ece9e2d1708",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['dialect', 'top_k']\n"
     ]
    }
   ],
   "source": [
    "from langchain import hub\n",
    "\n",
    "prompt_template = hub.pull(\"langchain-ai/sql-agent-system-prompt\")\n",
    "\n",
    "assert len(prompt_template.messages) == 1\n",
    "print(prompt_template.input_variables)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3470ae96-e5e5-4717-a6d6-d7d28c7b7347",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = prompt_template.format(dialect=\"SQLite\", top_k=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97930c07-36d1-4137-94ae-fe5ac83ecc44",
   "metadata": {},
   "source": [
    "We then instantiate the agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "48bca92c-9b4b-4d5c-bcce-1b239c9e901c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.prebuilt import create_react_agent\n",
    "\n",
    "agent_executor = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09fb1845-1105-4f41-98b4-24756452a3e3",
   "metadata": {},
   "source": [
    "And issue it a query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "39e6d2bf-3194-4aba-854b-63faf919157b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Which country's customers spent the most?\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_list_tables (call_EBPjyfzqXzFutDn8BklYACLj)\n",
      " Call ID: call_EBPjyfzqXzFutDn8BklYACLj\n",
      " Args:\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_list_tables\n",
      "\n",
      "Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_schema (call_kGcnKpxRVFIY8dPjYIJbRoVU)\n",
      " Call ID: call_kGcnKpxRVFIY8dPjYIJbRoVU\n",
      " Args:\n",
      " table_names: Customer, Invoice, InvoiceLine\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_schema\n",
      "\n",
      "\n",
      "CREATE TABLE \"Customer\" (\n",
      "\t\"CustomerId\" INTEGER NOT NULL, \n",
      "\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
      "\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
      "\t\"Company\" NVARCHAR(80), \n",
      "\t\"Address\" NVARCHAR(70), \n",
      "\t\"City\" NVARCHAR(40), \n",
      "\t\"State\" NVARCHAR(40), \n",
      "\t\"Country\" NVARCHAR(40), \n",
      "\t\"PostalCode\" NVARCHAR(10), \n",
      "\t\"Phone\" NVARCHAR(24), \n",
      "\t\"Fax\" NVARCHAR(24), \n",
      "\t\"Email\" NVARCHAR(60) NOT NULL, \n",
      "\t\"SupportRepId\" INTEGER, \n",
      "\tPRIMARY KEY (\"CustomerId\"), \n",
      "\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Customer table:\n",
      "CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
      "1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
      "2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
      "3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE \"Invoice\" (\n",
      "\t\"InvoiceId\" INTEGER NOT NULL, \n",
      "\t\"CustomerId\" INTEGER NOT NULL, \n",
      "\t\"InvoiceDate\" DATETIME NOT NULL, \n",
      "\t\"BillingAddress\" NVARCHAR(70), \n",
      "\t\"BillingCity\" NVARCHAR(40), \n",
      "\t\"BillingState\" NVARCHAR(40), \n",
      "\t\"BillingCountry\" NVARCHAR(40), \n",
      "\t\"BillingPostalCode\" NVARCHAR(10), \n",
      "\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
      "\tPRIMARY KEY (\"InvoiceId\"), \n",
      "\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Invoice table:\n",
      "InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
      "1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
      "2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
      "3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE \"InvoiceLine\" (\n",
      "\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
      "\t\"InvoiceId\" INTEGER NOT NULL, \n",
      "\t\"TrackId\" INTEGER NOT NULL, \n",
      "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
      "\t\"Quantity\" INTEGER NOT NULL, \n",
      "\tPRIMARY KEY (\"InvoiceLineId\"), \n",
      "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
      "\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from InvoiceLine table:\n",
      "InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n",
      "1\t1\t2\t0.99\t1\n",
      "2\t1\t4\t0.99\t1\n",
      "3\t2\t6\t0.99\t1\n",
      "*/\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_query (call_cTfI7OrY64FzJaDd49ILFWw7)\n",
      " Call ID: call_cTfI7OrY64FzJaDd49ILFWw7\n",
      " Args:\n",
      " query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_query\n",
      "\n",
      "[('USA', 523.06)]\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Customers from the USA spent the most, with a total amount spent of $523.06.\n"
     ]
    }
   ],
   "source": [
    "example_query = \"Which country's customers spent the most?\"\n",
    "\n",
    "events = agent_executor.stream(\n",
    "    {\"messages\": [(\"user\", example_query)]},\n",
    "    stream_mode=\"values\",\n",
    ")\n",
    "for event in events:\n",
    "    event[\"messages\"][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adbf3d8d-7570-45a5-950f-ce84db5145ab",
   "metadata": {},
   "source": [
    "We can also observe the agent recover from an error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "23c1235c-6d18-43e4-98ab-85b426b53d94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================\u001b[1m Human Message \u001b[0m=================================\n",
      "\n",
      "Who are the top 3 best selling artists?\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_query (call_xAkvYiRFM7nCMKXsDNvk1OMx)\n",
      " Call ID: call_xAkvYiRFM7nCMKXsDNvk1OMx\n",
      " Args:\n",
      " query: SELECT artist_name, SUM(quantity) AS total_sold FROM sales GROUP BY artist_name ORDER BY total_sold DESC LIMIT 3\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_query\n",
      "\n",
      "Error: (sqlite3.OperationalError) no such table: sales\n",
      "[SQL: SELECT artist_name, SUM(quantity) AS total_sold FROM sales GROUP BY artist_name ORDER BY total_sold DESC LIMIT 3]\n",
      "(Background on this error at: https://sqlalche.me/e/20/e3q8)\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_list_tables (call_K4Zvbowsq7XPgGFepbvc5G7i)\n",
      " Call ID: call_K4Zvbowsq7XPgGFepbvc5G7i\n",
      " Args:\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_list_tables\n",
      "\n",
      "Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_schema (call_tUztueSK7VO2klZ99xT4ZVhM)\n",
      " Call ID: call_tUztueSK7VO2klZ99xT4ZVhM\n",
      " Args:\n",
      " table_names: Artist, Album, InvoiceLine\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_schema\n",
      "\n",
      "\n",
      "CREATE TABLE \"Album\" (\n",
      "\t\"AlbumId\" INTEGER NOT NULL, \n",
      "\t\"Title\" NVARCHAR(160) NOT NULL, \n",
      "\t\"ArtistId\" INTEGER NOT NULL, \n",
      "\tPRIMARY KEY (\"AlbumId\"), \n",
      "\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Album table:\n",
      "AlbumId\tTitle\tArtistId\n",
      "1\tFor Those About To Rock We Salute You\t1\n",
      "2\tBalls to the Wall\t2\n",
      "3\tRestless and Wild\t2\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE \"Artist\" (\n",
      "\t\"ArtistId\" INTEGER NOT NULL, \n",
      "\t\"Name\" NVARCHAR(120), \n",
      "\tPRIMARY KEY (\"ArtistId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from Artist table:\n",
      "ArtistId\tName\n",
      "1\tAC/DC\n",
      "2\tAccept\n",
      "3\tAerosmith\n",
      "*/\n",
      "\n",
      "\n",
      "CREATE TABLE \"InvoiceLine\" (\n",
      "\t\"InvoiceLineId\" INTEGER NOT NULL, \n",
      "\t\"InvoiceId\" INTEGER NOT NULL, \n",
      "\t\"TrackId\" INTEGER NOT NULL, \n",
      "\t\"UnitPrice\" NUMERIC(10, 2) NOT NULL, \n",
      "\t\"Quantity\" INTEGER NOT NULL, \n",
      "\tPRIMARY KEY (\"InvoiceLineId\"), \n",
      "\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
      "\tFOREIGN KEY(\"InvoiceId\") REFERENCES \"Invoice\" (\"InvoiceId\")\n",
      ")\n",
      "\n",
      "/*\n",
      "3 rows from InvoiceLine table:\n",
      "InvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n",
      "1\t1\t2\t0.99\t1\n",
      "2\t1\t4\t0.99\t1\n",
      "3\t2\t6\t0.99\t1\n",
      "*/\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "Tool Calls:\n",
      " sql_db_query (call_tVtLQIRPmCM6pukgpHFfq86A)\n",
      " Call ID: call_tVtLQIRPmCM6pukgpHFfq86A\n",
      " Args:\n",
      " query: SELECT Artist.Name AS artist_name, SUM(InvoiceLine.Quantity) AS total_sold FROM Artist JOIN Album ON Artist.ArtistId = Album.ArtistId JOIN Track ON Album.AlbumId = Track.AlbumId JOIN InvoiceLine ON Track.TrackId = InvoiceLine.TrackId GROUP BY Artist.Name ORDER BY total_sold DESC LIMIT 3\n",
      "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
      "Name: sql_db_query\n",
      "\n",
      "[('Iron Maiden', 140), ('U2', 107), ('Metallica', 91)]\n",
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "The top 3 best selling artists are:\n",
      "1. Iron Maiden - 140 units sold\n",
      "2. U2 - 107 units sold\n",
      "3. Metallica - 91 units sold\n"
     ]
    }
   ],
   "source": [
    "example_query = \"Who are the top 3 best selling artists?\"\n",
    "\n",
    "events = agent_executor.stream(\n",
    "    {\"messages\": [(\"user\", example_query)]},\n",
    "    stream_mode=\"values\",\n",
    ")\n",
    "for event in events:\n",
    "    event[\"messages\"][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73521f1b-be03-44e6-8b27-a9a46ae8e962",
   "metadata": {},
   "source": [
    "## Specific functionality\n",
    "\n",
    "`SQLDatabaseToolkit` implements a [.get_context](https://python.langchain.com/api_reference/community/agent_toolkits/langchain_community.agent_toolkits.sql.toolkit.SQLDatabaseToolkit.html#langchain_community.agent_toolkits.sql.toolkit.SQLDatabaseToolkit.get_context) method as a convenience for use in prompts or other contexts.\n",
    "\n",
    "**⚠️ Disclaimer ⚠️**: The agent may generate insert/update/delete queries. When this is not expected, use a custom prompt or create a SQL user without write permissions.\n",
    "\n",
    "An end user might overload your SQL database by asking a simple question such as \"run the biggest query possible\". The generated query might look like:\n",
    "\n",
    "```sql\n",
    "SELECT * FROM \"public\".\"users\"\n",
    "    JOIN \"public\".\"user_permissions\" ON \"public\".\"users\".id = \"public\".\"user_permissions\".user_id\n",
    "    JOIN \"public\".\"projects\" ON \"public\".\"users\".id = \"public\".\"projects\".user_id\n",
    "    JOIN \"public\".\"events\" ON \"public\".\"projects\".id = \"public\".\"events\".project_id;\n",
    "```\n",
    "\n",
    "For a transactional SQL database, if one of the tables above contains millions of rows, the query could slow down other applications that share the same database.\n",
    "\n",
    "Many data warehouse-oriented databases support user-level quotas for limiting resource usage."
   ]
  },
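  {
   "cell_type": "markdown",
   "id": "sqlite-time-budget-note",
   "metadata": {},
   "source": [
    "Quotas are one way to cap resource usage on the server side. As a local analogue (a swapped-in technique, not a feature of `SQLDatabaseToolkit`), Python's built-in `sqlite3` module can register a progress handler that interrupts a statement once it exceeds a time budget. The sketch below assumes a 0.5-second budget and runs a deliberately expensive self-join on a hypothetical `numbers` table. How you would wire such a guard into the connection behind your `SQLDatabase` engine is left as an assumption and is not shown. The handler is checked every 10,000 SQLite virtual-machine instructions, so the budget is enforced only approximately; a smaller interval reacts faster at the cost of more overhead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sqlite-time-budget-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "import time\n",
    "\n",
    "\n",
    "def run_with_time_budget(conn: sqlite3.Connection, sql: str, budget_s: float = 0.5):\n",
    "    \"\"\"Run `sql`, interrupting it if it runs longer than `budget_s` seconds.\"\"\"\n",
    "    deadline = time.monotonic() + budget_s\n",
    "\n",
    "    def past_deadline() -> int:\n",
    "        # A non-zero return value tells SQLite to abort the running statement\n",
    "        return 1 if time.monotonic() > deadline else 0\n",
    "\n",
    "    # Invoke the handler every 10,000 SQLite virtual-machine instructions\n",
    "    conn.set_progress_handler(past_deadline, 10_000)\n",
    "    try:\n",
    "        return conn.execute(sql).fetchall()\n",
    "    finally:\n",
    "        conn.set_progress_handler(None, 0)  # remove the guard again\n",
    "\n",
    "\n",
    "conn = sqlite3.connect(\":memory:\")\n",
    "conn.execute(\"CREATE TABLE numbers (n INTEGER)\")\n",
    "conn.executemany(\"INSERT INTO numbers VALUES (?)\", [(i,) for i in range(2000)])\n",
    "\n",
    "# A cheap query finishes well within the budget\n",
    "print(run_with_time_budget(conn, \"SELECT COUNT(*) FROM numbers\"))  # [(2000,)]\n",
    "\n",
    "# A triple self-join (2000**3 row combinations) is interrupted instead\n",
    "try:\n",
    "    run_with_time_budget(conn, \"SELECT COUNT(*) FROM numbers a, numbers b, numbers c\")\n",
    "    interrupted = False\n",
    "except sqlite3.OperationalError as exc:\n",
    "    interrupted = True\n",
    "    print(\"query aborted:\", exc)  # typically \"interrupted\""
   ]
  },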
  {
   "cell_type": "markdown",
   "id": "1aa8a7e3-87ca-4963-a224-0cbdc9d88714",
   "metadata": {},
   "source": [
    "## API reference\n",
    "\n",
    "For detailed documentation of all `SQLDatabaseToolkit` features and configurations, head to the [API reference](https://python.langchain.com/api_reference/community/agent_toolkits/langchain_community.agent_toolkits.sql.toolkit.SQLDatabaseToolkit.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}