{ "cells": [ { "cell_type": "markdown", "id": "fc935871-7640-41c6-b798-58514d860fe0", "metadata": {}, "source": [ "## LLaMA2 chat with SQL\n", " \n", "This Cookbook shows how to use LangChain's `SQLDatabaseChain` with LLaMA2 to chat about structured data stored in a SQL DB. \n", "\n", "* With the 2023-24 NBA season around the corner, we use NBA roster info saved in a SQLite DB to show you how to ask LLaMA2 questions about your favorite teams or players. \n", "\n", "* Because the `SQLDatabaseChain` implementation still lives in the `langchain_experimental` package, you'll also see some of the issues that come with using cutting-edge experimental features, and how we resolve some of them but not others." ] }, { "cell_type": "code", "execution_count": null, "id": "81adcf8b-395a-4f02-8749-ac976942b446", "metadata": {}, "outputs": [], "source": [ "! pip install langchain replicate langchain_experimental" ] }, { "cell_type": "markdown", "id": "8e13ed66-300b-4a23-b8ac-44df68ee4733", "metadata": {}, "source": [ "## LLM\n", "\n", "Use the Replicate API to run llama-2-13b-chat." ] }, { "cell_type": "code", "execution_count": 27, "id": "416ecce7-8aec-4145-b3f1-587a9b8a4fe9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Init param `input` is deprecated, please use `model_kwargs` instead.\n" ] } ], "source": [ "import os\n", "from getpass import getpass\n", "from langchain.llms import Replicate\n", "\n", "# Uncomment to enter your Replicate API token interactively:\n", "# REPLICATE_API_TOKEN = getpass()\n", "# os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n", "\n", "# Replicate model identifier (owner/name:version)\n", "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n", "\n", "# Set the system_prompt so that LLaMA will generate only the SQL statement, instead of being wordy and adding something like\n", "# \"Sure! 
Here's the SQL query for the given input question: \" before the SQL query; otherwise custom parsing will be needed.\n", "llm = Replicate(\n", "    model=llama2_13b_chat,\n", "    # `input` is deprecated in favor of `model_kwargs`\n", "    model_kwargs={\"temperature\": 0.01, \"max_length\": 500, \"top_p\": 1},\n", ")" ] }, { "cell_type": "markdown", "id": "80222165-f353-4e35-a123-5f70fd70c6c8", "metadata": {}, "source": [ "## DB\n", "\n", "Connect to the SQLite DB (`nba_roster.db`), which is stored in the same directory as this notebook." ] }, { "cell_type": "code", "execution_count": 28, "id": "025bdd82-3bb1-4948-bc7c-c3ccd94fd05c", "metadata": {}, "outputs": [], "source": [ "from langchain.utilities import SQLDatabase\n", "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n", "\n", "# Return the table schema; the argument (the chain input) is ignored\n", "def get_schema(_):\n", "    return db.get_table_info()\n", "\n", "# Run a SQL query against the DB and return the result as a string\n", "def run_query(query):\n", "    return db.run(query)" ] }, { "cell_type": "markdown", "id": "654b3577-baa2-4e12-a393-f40e5db49ac7", "metadata": {}, "source": [ "## Query a SQL DB \n", "\n", "Follow the workflow from the [LangChain SQL DB cookbook](https://python.langchain.com/docs/expression_language/cookbook/sql_db)." ] }, { "cell_type": "code", "execution_count": 72, "id": "5a4933ea-d9c0-4b0a-8177-ba4490c6532b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\" SELECT * FROM nba_roster WHERE NAME = 'Klay Thompson';\"" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Prompt\n", "from langchain.prompts import ChatPromptTemplate\n", "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query:\"\"\"\n", "prompt = ChatPromptTemplate.from_messages([\n", "    (\"system\", \"Given an input question, convert it to a SQL query. 
No pre-amble.\"),\n", "    (\"human\", template)\n", "])\n", "\n", "# Chain to query\n", "from langchain.schema.output_parser import StrOutputParser\n", "from langchain.schema.runnable import RunnablePassthrough\n", "\n", "sql_response = (\n", "    RunnablePassthrough.assign(schema=get_schema)\n", "    | prompt\n", "    | llm.bind(stop=[\"\\nSQLResult:\"])\n", "    | StrOutputParser()\n", ")\n", "\n", "sql_response.invoke({\"question\": \"What team is Klay Thompson on?\"})" ] }, { "cell_type": "markdown", "id": "a0e9e2c8-9b88-4853-ac86-001bc6cc6695", "metadata": {}, "source": [ "The [LangSmith trace](https://smith.langchain.com/public/afa56a06-b4e2-469a-a60f-c1746e75e42b/r) gives us visibility into the chain! " ] }, { "cell_type": "code", "execution_count": 68, "id": "2a2825e3-c1b6-4f7d-b9c9-d9835de323bb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\" Sure! Here's the natural language response based on the given SQL query and response:\\n\\nThere are 30 unique teams in the NBA roster.\"" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Chain to answer\n", "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query: {query}\n", "SQL Response: {response}\"\"\"\n", "prompt_response = ChatPromptTemplate.from_messages([\n", "    (\"system\", \"Given an input question and SQL response, convert it to a natural language answer. 
No pre-amble.\"),\n", "    (\"human\", template)\n", "])\n", "\n", "full_chain = (\n", "    RunnablePassthrough.assign(query=sql_response)\n", "    | RunnablePassthrough.assign(\n", "        schema=get_schema,\n", "        response=lambda x: db.run(x[\"query\"]),\n", "    )\n", "    | prompt_response\n", "    | llm\n", ")\n", "\n", "full_chain.invoke({\"question\": \"How many unique teams are there?\"})" ] }, { "cell_type": "markdown", "id": "ec17b3ee-6618-4681-b6df-089bbb5ffcd7", "metadata": {}, "source": [ "Again, the [LangSmith trace](https://smith.langchain.com/public/10420721-746a-4806-8ecf-d6dc6399d739/r) gives us visibility into the chain! " ] }, { "cell_type": "markdown", "id": "1e85381b-1edc-4bb3-a7bd-2ab23f81e54d", "metadata": {}, "source": [ "## Chat with a SQL DB \n", "\n", "Add memory, so that follow-up questions can refer to earlier ones in the conversation." ] }, { "cell_type": "code", "execution_count": 74, "id": "1985aa1c-eb8f-4fb1-a54f-c8aa10744687", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"SELECT Team \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\"" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Prompt\n", "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n", "template = \"\"\"Based on the table schema below, write a SQL query that would answer the user's question:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query:\"\"\"\n", "prompt = ChatPromptTemplate.from_messages([\n", "    (\"system\", \"Given an input question, convert it to a SQL query. 
No pre-amble.\"),\n", "    MessagesPlaceholder(variable_name=\"history\"),\n", "    (\"human\", template)\n", "])\n", "\n", "# Chain to query with memory\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n", "from langchain.schema.runnable import RunnableLambda\n", "\n", "memory = ConversationBufferMemory(return_messages=True)\n", "\n", "sql_chain = (\n", "    RunnablePassthrough.assign(\n", "        schema=get_schema,\n", "        history=RunnableLambda(lambda x: memory.load_memory_variables(x)[\"history\"])\n", "    )\n", "    | prompt\n", "    | llm.bind(stop=[\"\\nSQLResult:\"])\n", "    | StrOutputParser()\n", ")\n", "\n", "# Save the question and the generated SQL to memory, then return the SQL\n", "def save(input_output):\n", "    output = {\"output\": input_output.pop(\"output\")}\n", "    memory.save_context(input_output, output)\n", "    return output[\"output\"]\n", "\n", "sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save\n", "sql_response_memory.invoke({\"question\": \"What team is Klay Thompson on?\"})" ] }, { "cell_type": "code", "execution_count": 75, "id": "0b45818a-1498-441d-b82d-23c29428c2bb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"SELECT SALARY \\nFROM nba_roster \\nWHERE NAME = 'Klay Thompson'\"" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sql_response_memory.invoke({\"question\": \"What is his salary?\"})" ] }, { "cell_type": "code", "execution_count": 76, "id": "800a7a3b-f411-478b-af51-2310cd6e0425", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\" Sure thing! 
Here's the natural language response based on the given SQL query and response:\\n\\nKlay Thompson plays for the Golden State Warriors.\"" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Chain to answer\n", "template = \"\"\"Based on the table schema below, question, sql query, and sql response, write a natural language response:\n", "{schema}\n", "\n", "Question: {question}\n", "SQL Query: {query}\n", "SQL Response: {response}\"\"\"\n", "prompt_response = ChatPromptTemplate.from_messages([\n", "    (\"system\", \"Given an input question and SQL response, convert it to a natural language answer. No pre-amble.\"),\n", "    (\"human\", template)\n", "])\n", "\n", "full_chain = (\n", "    RunnablePassthrough.assign(query=sql_response_memory)\n", "    | RunnablePassthrough.assign(\n", "        schema=get_schema,\n", "        response=lambda x: db.run(x[\"query\"]),\n", "    )\n", "    | prompt_response\n", "    | llm\n", ")\n", "\n", "full_chain.invoke({\"question\": \"What team is Klay Thompson on?\"})" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }