mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-02 15:20:02 +00:00
Trying to unblock documentation build pipeline * Bump langgraph dep in docs * Update langgraph in lock file (resolves an issue in API reference generation)
1284 lines
78 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build a Question/Answering system over SQL data\n",
"\n",
":::info Prerequisites\n",
"\n",
"This guide assumes familiarity with the following concepts:\n",
"\n",
"- [Chat models](/docs/concepts/chat_models)\n",
"- [Tools](/docs/concepts/tools)\n",
"- [Agents](/docs/concepts/agents)\n",
"- [LangGraph](/docs/concepts/architecture/#langgraph)\n",
"\n",
":::\n",
"\n",
"Enabling an LLM system to query structured data can be qualitatively different from querying unstructured text data. Whereas the latter commonly involves generating text that can be searched against a vector database, the approach for structured data is often for the LLM to write and execute queries in a DSL, such as SQL. In this guide we'll go over the basic ways to create a Q&A system over tabular data in databases. We will cover implementations using both [chains](/docs/tutorials/sql_qa#chains) and [agents](/docs/tutorials/sql_qa#agents). These systems will allow us to ask a question about the data in a database and get back a natural language answer. The main difference between the two is that our agent can query the database in a loop as many times as it needs to answer the question.\n",
"\n",
"## ⚠️ Security note ⚠️\n",
"\n",
"Building Q&A systems over SQL databases requires executing model-generated SQL queries. There are inherent risks in doing this. Make sure that your database connection permissions are always scoped as narrowly as possible for your chain/agent's needs. This will mitigate, though not eliminate, the risks of building a model-driven system. For more on general security best practices, [see here](/docs/security).\n",
"\n",
"## Architecture\n",
"\n",
"At a high level, the steps of these systems are:\n",
"\n",
"1. **Convert question to SQL query**: Model converts user input to a SQL query.\n",
"2. **Execute SQL query**: Execute the query.\n",
"3. **Answer the question**: Model responds to user input using the query results.\n",
"\n",
"Note that querying data in CSVs can follow a similar approach. See our [how-to guide](/docs/how_to/sql_csv) on question-answering over CSV data for more detail.\n",
"\n",
"## Setup\n",
"\n",
"First, install the required packages and set environment variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%capture --no-stderr\n",
"%pip install --upgrade --quiet langchain-community langgraph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"import getpass\n",
"import os\n",
"\n",
"# Comment out the lines below to opt out of using LangSmith in this notebook. Not required.\n",
"if not os.environ.get(\"LANGSMITH_API_KEY\"):\n",
"    os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass()\n",
"    os.environ[\"LANGSMITH_TRACING\"] = \"true\"\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sample data\n",
"\n",
"The example below uses a SQLite connection with the Chinook database, a sample database that represents a digital media store. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook. You can also download and build the database via the command line:\n",
"```bash\n",
"curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db\n",
"```\n",
"\n",
"Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']\n"
]
},
{
"data": {
"text/plain": [
"\"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]\""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.utilities import SQLDatabase\n",
"\n",
"db = SQLDatabase.from_uri(\"sqlite:///Chinook.db\")\n",
"print(db.dialect)\n",
"print(db.get_usable_table_names())\n",
"db.run(\"SELECT * FROM Artist LIMIT 10;\")"
]
},
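{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the table listing and query above can be approximated with Python's built-in `sqlite3` module. This sketch is only an illustration and does not show how `SQLDatabase` works internally (it goes through SQLAlchemy). It uses a tiny in-memory database with a hypothetical three-row `Artist` table instead of reading `Chinook.db`, and it reads table names from SQLite's `sqlite_master` catalog:\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"# In-memory stand-in for Chinook.db (hypothetical sample rows).\n",
"conn = sqlite3.connect(\":memory:\")\n",
"conn.execute(\"CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)\")\n",
"conn.executemany(\n",
"    \"INSERT INTO Artist (ArtistId, Name) VALUES (?, ?)\",\n",
"    [(1, \"AC/DC\"), (2, \"Accept\"), (3, \"Aerosmith\")],\n",
")\n",
"\n",
"# List user tables from SQLite's schema catalog.\n",
"table_names = [\n",
"    row[0]\n",
"    for row in conn.execute(\n",
"        \"SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name\"\n",
"    )\n",
"]\n",
"print(table_names)  # ['Artist']\n",
"\n",
"# Run a query; the driver returns rows as a list of tuples.\n",
"rows = conn.execute(\"SELECT * FROM Artist ORDER BY ArtistId LIMIT 2;\").fetchall()\n",
"print(rows)  # [(1, 'AC/DC'), (2, 'Accept')]\n",
"conn.close()\n",
"```\n",
"\n",
"`SQLDatabase.run` returns a string rendering of similar tuples, as the output of the cell above shows."
]
},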
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! We've got a SQL database that we can query. Now let's try hooking it up to an LLM.\n",
"\n",
"## Chains {#chains}\n",
"\n",
"Chains are compositions of predictable steps. In [LangGraph](/docs/concepts/architecture/#langgraph), we can represent a chain via a simple sequence of nodes. Let's create a sequence of steps that, given a question, does the following:\n",
"- converts the question into a SQL query;\n",
"- executes the query;\n",
"- uses the result to answer the original question.\n",
"\n",
"There are scenarios not supported by this arrangement. For example, this system will execute a SQL query for any user input, even \"hello\". Importantly, as we'll see below, some questions require more than one query to answer. We will address these scenarios in the Agents section.\n",
"\n",
"### Application state\n",
"\n",
"The LangGraph [state](https://langchain-ai.github.io/langgraph/concepts/low_level/#state) of our application controls what data is input to the application, transferred between steps, and output by the application. It is typically a `TypedDict`, but can also be a [Pydantic BaseModel](https://langchain-ai.github.io/langgraph/how-tos/state-model/).\n",
"\n",
"For this application, we can just keep track of the input question, generated query, query result, and generated answer:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from typing_extensions import TypedDict\n",
"\n",
"\n",
"class State(TypedDict):\n",
"    question: str\n",
"    query: str\n",
"    result: str\n",
"    answer: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we just need functions that operate on this state and populate its contents.\n",
"\n",
"### Convert question to SQL query\n",
"\n",
"The first step is to take the user input and convert it to a SQL query. To reliably obtain SQL queries (without markdown formatting, explanations, or clarifications), we will make use of LangChain's [structured output](/docs/concepts/structured_outputs/) abstraction.\n",
"\n",
"Let's select a chat model for our application:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
"<ChatModelTabs customVarName=\"llm\" />\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-4o\", temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's provide some instructions for our model:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m System Message \u001b[0m================================\n",
"\n",
"\n",
"Given an input question, create a syntactically correct \u001b[33;1m\u001b[1;3m{dialect}\u001b[0m query to\n",
"run to help find the answer. Unless the user specifies in their question a\n",
"specific number of examples they wish to obtain, always limit your query to\n",
"at most \u001b[33;1m\u001b[1;3m{top_k}\u001b[0m results. You can order the results by a relevant column to\n",
"return the most interesting examples in the database.\n",
"\n",
"Never query for all the columns from a specific table, only ask for the\n",
"few relevant columns given the question.\n",
"\n",
"Pay attention to use only the column names that you can see in the schema\n",
"description. Be careful to not query for columns that do not exist. Also,\n",
"pay attention to which column is in which table.\n",
"\n",
"Only use the following tables:\n",
"\u001b[33;1m\u001b[1;3m{table_info}\u001b[0m\n",
"\n",
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Question: \u001b[33;1m\u001b[1;3m{input}\u001b[0m\n"
]
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"system_message = \"\"\"\n",
"Given an input question, create a syntactically correct {dialect} query to\n",
"run to help find the answer. Unless the user specifies in their question a\n",
"specific number of examples they wish to obtain, always limit your query to\n",
"at most {top_k} results. You can order the results by a relevant column to\n",
"return the most interesting examples in the database.\n",
"\n",
"Never query for all the columns from a specific table, only ask for the\n",
"few relevant columns given the question.\n",
"\n",
"Pay attention to use only the column names that you can see in the schema\n",
"description. Be careful to not query for columns that do not exist. Also,\n",
"pay attention to which column is in which table.\n",
"\n",
"Only use the following tables:\n",
"{table_info}\n",
"\"\"\"\n",
"\n",
"user_prompt = \"Question: {input}\"\n",
"\n",
"query_prompt_template = ChatPromptTemplate(\n",
"    [(\"system\", system_message), (\"user\", user_prompt)]\n",
")\n",
"\n",
"for message in query_prompt_template.messages:\n",
"    message.pretty_print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prompt includes several parameters we will need to populate, such as the SQL dialect and table schemas. LangChain's [SQLDatabase](https://python.langchain.com/api_reference/community/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) object includes methods to help with this. Our `write_query` step will just populate these parameters and prompt a model to generate the SQL query:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from typing_extensions import Annotated\n",
"\n",
"\n",
"class QueryOutput(TypedDict):\n",
"    \"\"\"Generated SQL query.\"\"\"\n",
"\n",
"    query: Annotated[str, ..., \"Syntactically valid SQL query.\"]\n",
"\n",
"\n",
"def write_query(state: State):\n",
"    \"\"\"Generate SQL query to fetch information.\"\"\"\n",
"    prompt = query_prompt_template.invoke(\n",
"        {\n",
"            \"dialect\": db.dialect,\n",
"            \"top_k\": 10,\n",
"            \"table_info\": db.get_table_info(),\n",
"            \"input\": state[\"question\"],\n",
"        }\n",
"    )\n",
"    structured_llm = llm.with_structured_output(QueryOutput)\n",
"    result = structured_llm.invoke(prompt)\n",
"    return {\"query\": result[\"query\"]}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's test it out:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'query': 'SELECT COUNT(*) as employee_count FROM Employee;'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"write_query({\"question\": \"How many Employees are there?\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Execute query\n",
"\n",
"**This is the most dangerous part of creating a SQL chain.** Consider carefully whether it is OK to run automated queries over your data. Minimize the database connection permissions as much as possible. Consider adding a human approval step to your chains before query execution (see below).\n",
"\n",
"To execute the query, we will load a tool from [langchain-community](/docs/concepts/architecture/#langchain-community). Our `execute_query` node will just wrap this tool:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool\n",
"\n",
"\n",
"def execute_query(state: State):\n",
"    \"\"\"Execute SQL query.\"\"\"\n",
"    execute_query_tool = QuerySQLDatabaseTool(db=db)\n",
"    return {\"result\": execute_query_tool.invoke(state[\"query\"])}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Testing this step:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'result': '[(8,)]'}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"execute_query({\"query\": \"SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;\"})"
]
},
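{
"cell_type": "markdown",
"metadata": {},
"source": [
"With SQLite, one way to minimize connection permissions is to open the database file in read-only mode. Model-generated write statements then fail inside the database driver itself. The sketch below uses Python's built-in `sqlite3` module and SQLite's `mode=ro` URI parameter. It is an illustration, not part of `SQLDatabase` or `QuerySQLDatabaseTool`, and it uses a throwaway database file with a hypothetical `Employee` table instead of `Chinook.db`:\n",
"\n",
"```python\n",
"import os\n",
"import sqlite3\n",
"import tempfile\n",
"\n",
"# Create a throwaway database file standing in for Chinook.db.\n",
"db_path = os.path.join(tempfile.mkdtemp(), \"example.db\")\n",
"setup_conn = sqlite3.connect(db_path)\n",
"setup_conn.execute(\"CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY)\")\n",
"setup_conn.executemany(\"INSERT INTO Employee VALUES (?)\", [(1,), (2,), (3,)])\n",
"setup_conn.commit()\n",
"setup_conn.close()\n",
"\n",
"# Reopen the same file read-only through a SQLite URI.\n",
"ro_conn = sqlite3.connect(f\"file:{db_path}?mode=ro\", uri=True)\n",
"\n",
"# Reads work as usual.\n",
"count = ro_conn.execute(\"SELECT COUNT(EmployeeId) FROM Employee;\").fetchone()[0]\n",
"print(count)  # 3\n",
"\n",
"# Writes are rejected by SQLite itself.\n",
"try:\n",
"    ro_conn.execute(\"DELETE FROM Employee;\")\n",
"    write_blocked = False\n",
"except sqlite3.OperationalError as exc:\n",
"    write_blocked = True\n",
"    print(f\"Write rejected: {exc}\")\n",
"finally:\n",
"    ro_conn.close()\n",
"```\n",
"\n",
"SQLAlchemy, which `SQLDatabase.from_uri` builds on, documents a similar URI form for SQLite (for example `sqlite:///file:Chinook.db?mode=ro&uri=true`). Check the SQLAlchemy SQLite dialect documentation for your version before relying on it. For server databases, a dedicated read-only user serves the same purpose."
]
},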
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate answer\n",
"\n",
"The last step generates an answer to the question using the information pulled from the database:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def generate_answer(state: State):\n",
"    \"\"\"Answer question using retrieved information as context.\"\"\"\n",
"    prompt = (\n",
"        \"Given the following user question, corresponding SQL query, \"\n",
"        \"and SQL result, answer the user question.\\n\\n\"\n",
"        f\"Question: {state['question']}\\n\"\n",
"        f\"SQL Query: {state['query']}\\n\"\n",
"        f\"SQL Result: {state['result']}\"\n",
"    )\n",
"    response = llm.invoke(prompt)\n",
"    return {\"answer\": response.content}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Orchestrating with LangGraph\n",
"\n",
"Finally, we compile our application into a single `graph` object. In this case, we are just connecting the three steps into a single sequence."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from langgraph.graph import START, StateGraph\n",
"\n",
"graph_builder = StateGraph(State).add_sequence(\n",
"    [write_query, execute_query, generate_answer]\n",
")\n",
"graph_builder.add_edge(START, \"write_query\")\n",
"graph = graph_builder.compile()"
]
},
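{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, this sequence runs each node in order and merges the partial dictionary each node returns into the shared state. The plain-Python sketch below only illustrates that idea. It is not how LangGraph executes graphs: there is no checkpointing, reducer logic, or streaming machinery here. The node functions are hypothetical stubs that skip the model and database calls and return fixed values:\n",
"\n",
"```python\n",
"from typing import Callable, TypedDict\n",
"\n",
"\n",
"class State(TypedDict, total=False):\n",
"    question: str\n",
"    query: str\n",
"    result: str\n",
"    answer: str\n",
"\n",
"\n",
"# Hypothetical stub nodes: each returns only the keys it updates.\n",
"def write_query(state: State) -> dict:\n",
"    return {\"query\": \"SELECT COUNT(*) FROM Employee;\"}\n",
"\n",
"\n",
"def execute_query(state: State) -> dict:\n",
"    return {\"result\": \"[(8,)]\"}\n",
"\n",
"\n",
"def generate_answer(state: State) -> dict:\n",
"    count = state[\"result\"].strip(\"[(,)]\")\n",
"    return {\"answer\": f\"There are {count} employees.\"}\n",
"\n",
"\n",
"def run_sequence(nodes: list[Callable[[State], dict]], state: State) -> dict:\n",
"    \"\"\"Run nodes in order, merging each partial update into a copy of the state.\"\"\"\n",
"    current = dict(state)\n",
"    for node in nodes:\n",
"        update = node(current)\n",
"        print({node.__name__: update})  # similar in shape to stream_mode=\"updates\"\n",
"        current.update(update)\n",
"    return current\n",
"\n",
"\n",
"final_state = run_sequence(\n",
"    [write_query, execute_query, generate_answer],\n",
"    {\"question\": \"How many employees are there?\"},\n",
")\n",
"print(final_state[\"answer\"])  # There are 8 employees.\n",
"```\n",
"\n",
"The `dict.update` call stands in for LangGraph's default handling of state keys that have no reducer: the value a node returns overwrites the previous one."
]
},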
{
"cell_type": "markdown",
"metadata": {},
"source": [
"LangGraph also comes with built-in utilities for visualizing the control flow of your application:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import Image, display\n",
"\n",
"display(Image(graph.get_graph().draw_mermaid_png()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's test our application! Note that we can stream the results of individual steps:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'write_query': {'query': 'SELECT COUNT(*) as employee_count FROM Employee;'}}\n",
"{'execute_query': {'result': '[(8,)]'}}\n",
"{'generate_answer': {'answer': 'There are 8 employees in total.'}}\n"
]
}
],
"source": [
"for step in graph.stream(\n",
"    {\"question\": \"How many employees are there?\"}, stream_mode=\"updates\"\n",
"):\n",
"    print(step)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check out the [LangSmith trace](https://smith.langchain.com/public/30a79380-6ba6-46af-8bd9-5d1df0b9ccca/r)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Human-in-the-loop\n",
"\n",
"LangGraph supports a number of features that can be useful for this workflow. One of them is [human-in-the-loop](https://langchain-ai.github.io/langgraph/concepts/human_in_the_loop/): we can interrupt our application before sensitive steps (such as the execution of a SQL query) for human review. This is enabled by LangGraph's [persistence](https://langchain-ai.github.io/langgraph/concepts/persistence/) layer, which saves run progress to your storage of choice. Below, we specify storage in-memory:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from langgraph.checkpoint.memory import MemorySaver\n",
"\n",
"memory = MemorySaver()\n",
"graph = graph_builder.compile(checkpointer=memory, interrupt_before=[\"execute_query\"])\n",
"\n",
"# Now that we're using persistence, we need to specify a thread ID\n",
"# so that we can continue the run after review.\n",
"config = {\"configurable\": {\"thread_id\": \"1\"}}"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(Image(graph.get_graph().draw_mermaid_png()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's repeat the same run, adding in a simple yes/no approval step:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'write_query': {'query': 'SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;'}}\n",
"{'__interrupt__': ()}\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"Do you want to execute the query? (yes/no): yes\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'execute_query': {'result': '[(8,)]'}}\n",
"{'generate_answer': {'answer': 'There are 8 employees.'}}\n"
]
}
],
"source": [
"for step in graph.stream(\n",
"    {\"question\": \"How many employees are there?\"},\n",
"    config,\n",
"    stream_mode=\"updates\",\n",
"):\n",
"    print(step)\n",
"\n",
"try:\n",
"    user_approval = input(\"Do you want to execute the query? (yes/no): \")\n",
"except Exception:\n",
"    user_approval = \"no\"\n",
"\n",
"if user_approval.lower() == \"yes\":\n",
"    # If approved, continue the graph execution\n",
"    for step in graph.stream(None, config, stream_mode=\"updates\"):\n",
"        print(step)\n",
"else:\n",
"    print(\"Operation cancelled by user.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See the LangGraph [human-in-the-loop guide](https://langchain-ai.github.io/langgraph/concepts/human_in_the_loop/) for more detail and examples."
]
},
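{
"cell_type": "markdown",
"metadata": {},
"source": [
"Human review is not the only possible safeguard. As an additional, deliberately simple check, the execution step could refuse anything that does not look like a single read-only statement. The helper below is a hypothetical sketch, not a LangChain or LangGraph API. It relies on plain string inspection, so it can be fooled. It will also reject some harmless queries, such as ones whose string literals contain `;` or a keyword like `update`. It is no substitute for narrowly scoped database permissions:\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Keywords that indicate a write or schema change (hypothetical, non-exhaustive list).\n",
"FORBIDDEN = re.compile(\n",
"    r\"\\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|REPLACE|TRUNCATE|ATTACH|DETACH|PRAGMA|VACUUM)\\b\",\n",
"    re.IGNORECASE,\n",
")\n",
"\n",
"\n",
"def is_read_only_query(query: str) -> bool:\n",
"    \"\"\"Return True if `query` looks like a single SELECT or WITH ... SELECT statement.\"\"\"\n",
"    statement = query.strip().rstrip(\";\").strip()\n",
"    if not statement or \";\" in statement:\n",
"        # Reject empty input and multiple statements.\n",
"        return False\n",
"    if not re.match(r\"(SELECT|WITH)\\b\", statement, re.IGNORECASE):\n",
"        return False\n",
"    return FORBIDDEN.search(statement) is None\n",
"\n",
"\n",
"print(is_read_only_query(\"SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;\"))  # True\n",
"print(is_read_only_query(\"DROP TABLE Employee;\"))  # False\n",
"print(is_read_only_query(\"SELECT 1; DELETE FROM Employee\"))  # False\n",
"```\n",
"\n",
"In the chain built earlier, a check like this could run at the start of `execute_query` and return an error message instead of calling the tool."
]
},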
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Next steps\n",
"\n",
"For more complex query generation, we may want to create few-shot prompts or add query-checking steps. For these and other advanced techniques, check out:\n",
"\n",
"* [Prompting strategies](/docs/how_to/sql_prompting): Advanced prompt engineering techniques.\n",
"* [Query checking](/docs/how_to/sql_query_checking): Add query validation and error handling.\n",
"* [Large databases](/docs/how_to/sql_large_db): Techniques for working with large databases."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agents {#agents}\n",
"\n",
"[Agents](/docs/concepts/agents) leverage the reasoning capabilities of LLMs to make decisions during execution. Using agents lets you give the model more discretion over the query generation and execution process. Although their behavior is less predictable than the chain above, they offer some advantages:\n",
"\n",
"- They can query the database as many times as needed to answer the user question.\n",
"- They can recover from errors by running a generated query, catching the traceback and regenerating it correctly.\n",
"- They can answer questions based on the database's schema as well as on its content (like describing a specific table).\n",
"\n",
"Below we assemble a minimal SQL agent. We will equip it with a set of tools using LangChain's [SQLDatabaseToolkit](https://python.langchain.com/api_reference/community/agent_toolkits/langchain_community.agent_toolkits.sql.toolkit.SQLDatabaseToolkit.html). Using LangGraph's [pre-built ReAct agent constructor](https://langchain-ai.github.io/langgraph/how-tos/#langgraph.prebuilt.chat_agent_executor.create_react_agent), we can do this in one line.\n",
"\n",
":::tip\n",
"\n",
"Check out LangGraph's [SQL Agent Tutorial](https://langchain-ai.github.io/langgraph/tutorials/sql-agent/) for a more advanced formulation of a SQL agent.\n",
"\n",
":::\n",
"\n",
"The `SQLDatabaseToolkit` includes tools that can:\n",
"\n",
"* Create and execute queries\n",
"* Check query syntax\n",
"* Retrieve table descriptions\n",
"* ... and more"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[QuerySQLDatabaseTool(description=\"Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.\", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),\n",
" InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),\n",
" ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>),\n",
" QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10d5f9120>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x119315480>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x119317550>, root_client=<openai.OpenAI object at 0x10d5f8df0>, root_async_client=<openai.AsyncOpenAI object at 0x1193154e0>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\\n{query}\\nDouble check the {dialect} query above for common mistakes, including:\\n- Using NOT IN with NULL values\\n- Using UNION when UNION ALL should have been used\\n- Using BETWEEN for exclusive ranges\\n- Data type mismatch in predicates\\n- Properly quoting identifiers\\n- Using the correct number of arguments for functions\\n- Casting to the correct data type\\n- Using the proper columns for joins\\n\\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\\n\\nOutput the final SQL query only.\\n\\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x119315480>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x119317550>, root_client=<openai.OpenAI object at 0x10d5f8df0>, root_async_client=<openai.AsyncOpenAI object at 0x1193154e0>, model_name='gpt-4o', temperature=0.0, model_kwargs={}, openai_api_key=SecretStr('**********')), output_parser=StrOutputParser(), llm_kwargs={}))]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
"\n",
"toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
"\n",
"tools = toolkit.get_tools()\n",
"\n",
"tools"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### System Prompt\n",
"\n",
"We will also want a system prompt for our agent, consisting of instructions for how it should behave. Note that the prompt has several parameters, which we fill in below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"system_message = \"\"\"\n",
"You are an agent designed to interact with a SQL database.\n",
"Given an input question, create a syntactically correct {dialect} query to run,\n",
"then look at the results of the query and return the answer. Unless the user\n",
"specifies a specific number of examples they wish to obtain, always limit your\n",
"query to at most {top_k} results.\n",
"\n",
"You can order the results by a relevant column to return the most interesting\n",
"examples in the database. Never query for all the columns from a specific table,\n",
"only ask for the relevant columns given the question.\n",
"\n",
"You MUST double check your query before executing it. If you get an error while\n",
"executing a query, rewrite the query and try again.\n",
"\n",
"DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the\n",
"database.\n",
"\n",
"To start you should ALWAYS look at the tables in the database to see what you\n",
"can query. Do NOT skip this step.\n",
"\n",
"Then you should query the schema of the most relevant tables.\n",
"\"\"\".format(\n",
"    dialect=\"SQLite\",\n",
"    top_k=5,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initializing agent\n",
"\n",
"We will use a prebuilt [LangGraph](/docs/concepts/architecture/#langgraph) agent constructor to build our agent:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import HumanMessage\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"agent_executor = create_react_agent(llm, tools, prompt=system_message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Consider how the agent responds to the question below:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"Which country's customers spent the most?\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"Tool Calls:\n",
" sql_db_list_tables (call_tFp7HYD6sAAmCShgeqkVZH6Q)\n",
" Call ID: call_tFp7HYD6sAAmCShgeqkVZH6Q\n",
" Args:\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: sql_db_list_tables\n",
"\n",
"Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"Tool Calls:\n",
" sql_db_schema (call_KJZ1Jx6JazyDdJa0uH1UeiOz)\n",
" Call ID: call_KJZ1Jx6JazyDdJa0uH1UeiOz\n",
" Args:\n",
" table_names: Customer, Invoice\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: sql_db_schema\n",
"\n",
"\n",
"CREATE TABLE \"Customer\" (\n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"FirstName\" NVARCHAR(40) NOT NULL, \n",
"\t\"LastName\" NVARCHAR(20) NOT NULL, \n",
"\t\"Company\" NVARCHAR(80), \n",
"\t\"Address\" NVARCHAR(70), \n",
"\t\"City\" NVARCHAR(40), \n",
"\t\"State\" NVARCHAR(40), \n",
"\t\"Country\" NVARCHAR(40), \n",
"\t\"PostalCode\" NVARCHAR(10), \n",
"\t\"Phone\" NVARCHAR(24), \n",
"\t\"Fax\" NVARCHAR(24), \n",
"\t\"Email\" NVARCHAR(60) NOT NULL, \n",
"\t\"SupportRepId\" INTEGER, \n",
"\tPRIMARY KEY (\"CustomerId\"), \n",
"\tFOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Customer table:\n",
"CustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n",
"1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n",
"2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n",
"3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3\n",
"*/\n",
"\n",
"\n",
"CREATE TABLE \"Invoice\" (\n",
"\t\"InvoiceId\" INTEGER NOT NULL, \n",
"\t\"CustomerId\" INTEGER NOT NULL, \n",
"\t\"InvoiceDate\" DATETIME NOT NULL, \n",
"\t\"BillingAddress\" NVARCHAR(70), \n",
"\t\"BillingCity\" NVARCHAR(40), \n",
"\t\"BillingState\" NVARCHAR(40), \n",
"\t\"BillingCountry\" NVARCHAR(40), \n",
"\t\"BillingPostalCode\" NVARCHAR(10), \n",
"\t\"Total\" NUMERIC(10, 2) NOT NULL, \n",
"\tPRIMARY KEY (\"InvoiceId\"), \n",
"\tFOREIGN KEY(\"CustomerId\") REFERENCES \"Customer\" (\"CustomerId\")\n",
")\n",
"\n",
"/*\n",
"3 rows from Invoice table:\n",
"InvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n",
"1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n",
"2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n",
"3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n",
"*/\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"Tool Calls:\n",
" sql_db_query_checker (call_AQuTGbgH63u4gPgyV723yrjX)\n",
" Call ID: call_AQuTGbgH63u4gPgyV723yrjX\n",
" Args:\n",
" query: SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: sql_db_query_checker\n",
"\n",
"```sql\n",
"SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;\n",
"```\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"Tool Calls:\n",
" sql_db_query (call_B88EwU44nwwpQL5M9nlcemSU)\n",
" Call ID: call_B88EwU44nwwpQL5M9nlcemSU\n",
" Args:\n",
" query: SELECT c.Country, SUM(i.Total) as TotalSpent FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1;\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: sql_db_query\n",
"\n",
"[('USA', 523.06)]\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"The country whose customers spent the most is the USA, with a total spending of 523.06.\n"
]
}
],
"source": [
"question = \"Which country's customers spent the most?\"\n",
"\n",
"for step in agent_executor.stream(\n",
|
|
" {\"messages\": [{\"role\": \"user\", \"content\": question}]},\n",
|
|
" stream_mode=\"values\",\n",
|
|
"):\n",
|
|
" step[\"messages\"][-1].pretty_print()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"You can also use the [LangSmith trace](https://smith.langchain.com/public/8af422aa-b651-4bfe-8683-e2a7f4ccd82c/r) to visualize these steps and associated metadata.\n",
|
|
"\n",
|
|
"Note that the agent executes multiple queries until it has the information it needs:\n",
|
|
"1. Lists the available tables;\n",
"2. Retrieves the schema for three tables;\n",
"3. Queries multiple tables via a join operation.\n",
|
|
"\n",
|
|
"The agent is then able to use the result of the final query to generate an answer to the original question.\n",
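"\n",
"To see what the final join-and-aggregate step computes, the sketch below runs the same query against a tiny in-memory SQLite database using Python's standard `sqlite3` module. The tables and rows are hypothetical toy stand-ins for the Chinook `Customer` and `Invoice` tables, not the real data, so the totals differ from the agent's answer:\n",
"\n",
"```python\n",
"import sqlite3\n",
"\n",
"# Hypothetical toy rows standing in for the Chinook Customer and Invoice tables\n",
"conn = sqlite3.connect(\":memory:\")\n",
"conn.execute(\"CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)\")\n",
"conn.execute(\n",
"    \"CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total NUMERIC)\"\n",
")\n",
"conn.executemany(\n",
"    \"INSERT INTO Customer VALUES (?, ?)\", [(1, \"Brazil\"), (2, \"Germany\"), (3, \"USA\")]\n",
")\n",
"conn.executemany(\n",
"    \"INSERT INTO Invoice VALUES (?, ?, ?)\",\n",
"    [(1, 1, 1.98), (2, 2, 3.96), (3, 3, 5.94), (4, 3, 1.98)],\n",
")\n",
"\n",
"# The same join + GROUP BY query the agent generated above\n",
"query = (\n",
"    \"SELECT c.Country, SUM(i.Total) AS TotalSpent \"\n",
"    \"FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId \"\n",
"    \"GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 1\"\n",
")\n",
"top_country = conn.execute(query).fetchall()\n",
"print(top_country)  # a single (Country, TotalSpent) row; with these toy rows, Country is 'USA'\n",
"```\n",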
|
|
"\n",
|
|
"The agent can similarly handle qualitative questions:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"================================\u001b[1m Human Message \u001b[0m=================================\n",
|
|
"\n",
|
|
"Describe the playlisttrack table\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"Tool Calls:\n",
|
|
" sql_db_list_tables (call_fMF8eTmX5TJDJjc3Mhdg52TI)\n",
|
|
" Call ID: call_fMF8eTmX5TJDJjc3Mhdg52TI\n",
|
|
" Args:\n",
|
|
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
|
"Name: sql_db_list_tables\n",
|
|
"\n",
|
|
"Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"Tool Calls:\n",
|
|
" sql_db_schema (call_W8Vkk4NEodkAAIg8nexAszUH)\n",
|
|
" Call ID: call_W8Vkk4NEodkAAIg8nexAszUH\n",
|
|
" Args:\n",
|
|
" table_names: PlaylistTrack\n",
|
|
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
|
"Name: sql_db_schema\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"PlaylistTrack\" (\n",
|
|
"\t\"PlaylistId\" INTEGER NOT NULL, \n",
|
|
"\t\"TrackId\" INTEGER NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"PlaylistId\", \"TrackId\"), \n",
|
|
"\tFOREIGN KEY(\"TrackId\") REFERENCES \"Track\" (\"TrackId\"), \n",
|
|
"\tFOREIGN KEY(\"PlaylistId\") REFERENCES \"Playlist\" (\"PlaylistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from PlaylistTrack table:\n",
|
|
"PlaylistId\tTrackId\n",
|
|
"1\t3402\n",
|
|
"1\t3389\n",
|
|
"1\t3390\n",
|
|
"*/\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"\n",
|
|
"The `PlaylistTrack` table is designed to associate tracks with playlists. It has the following structure:\n",
|
|
"\n",
|
|
"- **PlaylistId**: An integer that serves as a foreign key referencing the `Playlist` table. It is part of the composite primary key.\n",
|
|
"- **TrackId**: An integer that serves as a foreign key referencing the `Track` table. It is also part of the composite primary key.\n",
|
|
"\n",
|
|
"The primary key for this table is a composite key consisting of both `PlaylistId` and `TrackId`, ensuring that each track can be uniquely associated with a playlist. The table enforces referential integrity by linking to the `Track` and `Playlist` tables through foreign keys.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"question = \"Describe the playlisttrack table\"\n",
|
|
"\n",
|
|
"for step in agent_executor.stream(\n",
|
|
" {\"messages\": [{\"role\": \"user\", \"content\": question}]},\n",
|
|
" stream_mode=\"values\",\n",
|
|
"):\n",
|
|
" step[\"messages\"][-1].pretty_print()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Dealing with high-cardinality columns\n",
|
|
"\n",
|
|
"To filter on columns that contain proper nouns, such as addresses, song names, or artists, we first need to double-check the spelling so that the filter matches the values actually stored in the database.\n",
|
|
"\n",
|
|
"We can achieve this by creating a vector store with all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.\n",
|
|
"\n",
|
|
"First, we need the unique values for each entity we want. To get them, we define a function that runs a query and parses the result into a list of elements:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['In Through The Out Door',\n",
|
|
" 'Transmission',\n",
|
|
" 'Battlestar Galactica (Classic), Season',\n",
|
|
" 'A Copland Celebration, Vol. I',\n",
|
|
" 'Quiet Songs']"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import ast\n",
|
|
"import re\n",
|
|
"\n",
|
|
"\n",
|
|
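"def query_as_list(db, query):\n",
"    # db.run returns the rows as a single string, e.g. \"[('AC/DC',), ('Accept',)]\"\n",
"    res = db.run(query)\n",
"    # Parse the string into tuples, flatten them, and drop empty values such as None\n",
"    res = [el for sub in ast.literal_eval(res) for el in sub if el]\n",
"    # Remove standalone numbers (e.g. \"Season 1\" -> \"Season\") so near-duplicates collapse\n",
"    res = [re.sub(r\"\\b\\d+\\b\", \"\", string).strip() for string in res]\n",
"    # Deduplicate the cleaned values\n",
"    return list(set(res))\n",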
|
|
"\n",
|
|
"\n",
|
|
"artists = query_as_list(db, \"SELECT Name FROM Artist\")\n",
|
|
"albums = query_as_list(db, \"SELECT Title FROM Album\")\n",
|
|
"albums[:5]"
|
|
]
|
|
},
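{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make each step of `query_as_list` concrete, here is a self-contained sketch that applies the same parsing to a hard-coded string. It assumes, as the function above does, that `db.run` returns the rows as the string form of a list of tuples; the sample rows are illustrative only:\n",
"\n",
"```python\n",
"import ast\n",
"import re\n",
"\n",
"# Illustrative stand-in for the string that db.run() returns\n",
"raw = \"[('AC/DC',), ('Accept',), ('Battlestar Galactica (Classic), Season 1',), (None,)]\"\n",
"\n",
"rows = ast.literal_eval(raw)  # string -> list of 1-tuples\n",
"values = [el for row in rows for el in row if el]  # flatten and drop None\n",
"values = [re.sub(r\"\\b\\d+\\b\", \"\", v).strip() for v in values]  # remove standalone numbers\n",
"unique_values = sorted(set(values))  # deduplicate (sorted only for a stable printout)\n",
"print(unique_values)\n",
"# ['AC/DC', 'Accept', 'Battlestar Galactica (Classic), Season']\n",
"```\n",
"\n",
"Because standalone numbers are removed, entries that differ only by a number, such as \"Season 1\" and \"Season 2\", collapse into one value, which keeps the list of proper nouns shorter at the cost of dropping the number."
]
},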
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Using this function, we can create a **retriever tool** that the agent can execute at its discretion.\n",
|
|
"\n",
|
|
"Let's select an [embeddings model](/docs/integrations/text_embedding/) and [vector store](/docs/integrations/vectorstores/) for this step:\n",
|
|
"\n",
|
|
"**Select an embedding model**:\n",
|
|
"\n",
|
|
"import EmbeddingTabs from \"@theme/EmbeddingTabs\";\n",
|
|
"\n",
|
|
"<EmbeddingTabs/>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# | output: false\n",
|
|
"# | echo: false\n",
|
|
"\n",
|
|
"from langchain_openai import OpenAIEmbeddings\n",
|
|
"\n",
|
|
"embeddings = OpenAIEmbeddings()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"**Select a vector store**:\n",
|
|
"\n",
|
|
"import VectorStoreTabs from \"@theme/VectorStoreTabs\";\n",
|
|
"\n",
|
|
"<VectorStoreTabs/>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# | output: false\n",
|
|
"# | echo: false\n",
|
|
"\n",
|
|
"from langchain_core.vectorstores import InMemoryVectorStore\n",
|
|
"\n",
|
|
"vector_store = InMemoryVectorStore(embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"We can now construct a retrieval tool that can search over relevant proper nouns in the database:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.agents.agent_toolkits import create_retriever_tool\n",
|
|
"\n",
|
|
"_ = vector_store.add_texts(artists + albums)\n",
|
|
"retriever = vector_store.as_retriever(search_kwargs={\"k\": 5})\n",
|
|
"description = (\n",
|
|
" \"Use to look up values to filter on. Input is an approximate spelling \"\n",
|
|
" \"of the proper noun, output is valid proper nouns. Use the noun most \"\n",
|
|
" \"similar to the search.\"\n",
|
|
")\n",
|
|
"retriever_tool = create_retriever_tool(\n",
|
|
" retriever,\n",
|
|
" name=\"search_proper_nouns\",\n",
|
|
" description=description,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's try it out:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Alice In Chains\n",
|
|
"\n",
|
|
"Alanis Morissette\n",
|
|
"\n",
|
|
"Pearl Jam\n",
|
|
"\n",
|
|
"Pearl Jam\n",
|
|
"\n",
|
|
"Audioslave\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(retriever_tool.invoke(\"Alice Chains\"))"
|
|
]
|
|
},
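{
"cell_type": "markdown",
"metadata": {},
"source": [
"The retriever ranks the stored values by embedding similarity to the query. As a rough, dependency-free illustration of the same lookup idea, the sketch below swaps in character-level fuzzy matching from Python's standard-library `difflib` in place of embeddings; the candidate list is a small hand-picked sample from the output above:\n",
"\n",
"```python\n",
"from difflib import get_close_matches\n",
"\n",
"# Small sample of proper nouns, as they might be stored in the vector store\n",
"candidates = [\"Alice In Chains\", \"Alanis Morissette\", \"Pearl Jam\", \"Audioslave\"]\n",
"\n",
"# Return up to 3 candidates whose similarity ratio to the query is at least 0.6, best first\n",
"matches = get_close_matches(\"Alice Chains\", candidates, n=3, cutoff=0.6)\n",
"print(matches[0])  # Alice In Chains\n",
"```\n",
"\n",
"Unlike embeddings, this captures only spelling overlap, not meaning, but it shows why the agent should look up the stored spelling before writing a `WHERE` clause."
]
},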
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This way, if the agent determines it needs to write a filter based on an artist along the lines of \"Alice Chains\", it can first use the retriever tool to observe relevant values of a column.\n",
|
|
"\n",
|
|
"Putting this together:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Add to system message\n",
|
|
"suffix = (\n",
|
|
" \"If you need to filter on a proper noun like a Name, you must ALWAYS first look up \"\n",
|
|
" \"the filter value using the 'search_proper_nouns' tool! Do not try to \"\n",
|
|
" \"guess at the proper name - use this function to find similar ones.\"\n",
|
|
")\n",
|
|
"\n",
|
|
"system = f\"{system_message}\\n\\n{suffix}\"\n",
|
|
"\n",
|
|
"tools.append(retriever_tool)\n",
|
|
"\n",
|
|
"agent = create_react_agent(llm, tools, prompt=system)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"================================\u001b[1m Human Message \u001b[0m=================================\n",
|
|
"\n",
|
|
"How many albums does alis in chain have?\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"Tool Calls:\n",
|
|
" search_proper_nouns (call_8ryjsRPLAr79mM3Qvnq6gTOH)\n",
|
|
" Call ID: call_8ryjsRPLAr79mM3Qvnq6gTOH\n",
|
|
" Args:\n",
|
|
" query: alis in chain\n",
|
|
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
|
"Name: search_proper_nouns\n",
|
|
"\n",
|
|
"Alice In Chains\n",
|
|
"\n",
|
|
"Aisha Duo\n",
|
|
"\n",
|
|
"Xis\n",
|
|
"\n",
|
|
"Da Lama Ao Caos\n",
|
|
"\n",
|
|
"A-Sides\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"Tool Calls:\n",
|
|
" sql_db_list_tables (call_NJjtCpU89MBMplssjn1z0xzq)\n",
|
|
" Call ID: call_NJjtCpU89MBMplssjn1z0xzq\n",
|
|
" Args:\n",
|
|
" search_proper_nouns (call_1BfrueC9koSIyi4OfMu2Ao8q)\n",
|
|
" Call ID: call_1BfrueC9koSIyi4OfMu2Ao8q\n",
|
|
" Args:\n",
|
|
" query: Alice In Chains\n",
|
|
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
|
"Name: search_proper_nouns\n",
|
|
"\n",
|
|
"Alice In Chains\n",
|
|
"\n",
|
|
"Pearl Jam\n",
|
|
"\n",
|
|
"Pearl Jam\n",
|
|
"\n",
|
|
"Foo Fighters\n",
|
|
"\n",
|
|
"Soundgarden\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"Tool Calls:\n",
|
|
" sql_db_schema (call_Kn09w9jd9swcNzIZ1b5MlKID)\n",
|
|
" Call ID: call_Kn09w9jd9swcNzIZ1b5MlKID\n",
|
|
" Args:\n",
|
|
" table_names: Album, Artist\n",
|
|
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
|
"Name: sql_db_schema\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"Album\" (\n",
|
|
"\t\"AlbumId\" INTEGER NOT NULL, \n",
|
|
"\t\"Title\" NVARCHAR(160) NOT NULL, \n",
|
|
"\t\"ArtistId\" INTEGER NOT NULL, \n",
|
|
"\tPRIMARY KEY (\"AlbumId\"), \n",
|
|
"\tFOREIGN KEY(\"ArtistId\") REFERENCES \"Artist\" (\"ArtistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Album table:\n",
|
|
"AlbumId\tTitle\tArtistId\n",
|
|
"1\tFor Those About To Rock We Salute You\t1\n",
|
|
"2\tBalls to the Wall\t2\n",
|
|
"3\tRestless and Wild\t2\n",
|
|
"*/\n",
|
|
"\n",
|
|
"\n",
|
|
"CREATE TABLE \"Artist\" (\n",
|
|
"\t\"ArtistId\" INTEGER NOT NULL, \n",
|
|
"\t\"Name\" NVARCHAR(120), \n",
|
|
"\tPRIMARY KEY (\"ArtistId\")\n",
|
|
")\n",
|
|
"\n",
|
|
"/*\n",
|
|
"3 rows from Artist table:\n",
|
|
"ArtistId\tName\n",
|
|
"1\tAC/DC\n",
|
|
"2\tAccept\n",
|
|
"3\tAerosmith\n",
|
|
"*/\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"Tool Calls:\n",
|
|
" sql_db_query (call_WkHRiPcBoGN9bc58MIupRHKP)\n",
|
|
" Call ID: call_WkHRiPcBoGN9bc58MIupRHKP\n",
|
|
" Args:\n",
|
|
" query: SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')\n",
|
|
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
|
"Name: sql_db_query\n",
|
|
"\n",
|
|
"[(1,)]\n",
|
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
|
"\n",
|
|
"Alice In Chains has released 1 album in the database.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"question = \"How many albums does alis in chain have?\"\n",
|
|
"\n",
|
|
"for step in agent.stream(\n",
|
|
" {\"messages\": [{\"role\": \"user\", \"content\": question}]},\n",
|
|
" stream_mode=\"values\",\n",
|
|
"):\n",
|
|
" step[\"messages\"][-1].pretty_print()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"As both the streamed steps and the [LangSmith trace](https://smith.langchain.com/public/1d757ed2-5688-4458-9400-023594e2c5a7/r) show, the agent used the `search_proper_nouns` tool to find the correct spelling of the artist's name before querying the database."
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|