{ "cells": [ { "cell_type": "markdown", "id": "f1390152", "metadata": {}, "source": [ "# MRKL\n", "\n", "This notebook showcases using an agent to replicate the MRKL chain." ] }, { "cell_type": "markdown", "id": "39ea3638", "metadata": {}, "source": [ "This uses the example Chinook database.\n", "To set it up, follow the instructions at https://database.guide/2-sample-databases-sqlite/ and place the `.db` file in a `notebooks` folder at the root of this repository." ] }, { "cell_type": "code", "execution_count": 8, "id": "ac561cc4", "metadata": {}, "outputs": [], "source": [ "from langchain import LLMMathChain, OpenAI, SerpAPIChain, SQLDatabase, SQLDatabaseChain\n", "from langchain.agents import initialize_agent, Tool" ] }, { "cell_type": "code", "execution_count": 9, "id": "07e96d99", "metadata": {}, "outputs": [], "source": [ "llm = OpenAI(temperature=0)\n", "search = SerpAPIChain()\n", "llm_math_chain = LLMMathChain(llm=llm, verbose=True)\n", "db = SQLDatabase.from_uri(\"sqlite:///../../../notebooks/Chinook.db\")\n", "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", "# The MRKL agent gets three tools: web search, a calculator, and the FooBar (Chinook) database.\n", "tools = [\n", " Tool(\n", " name=\"Search\",\n", " func=search.run,\n", " description=\"useful for when you need to answer questions about current events\"\n", " ),\n", " Tool(\n", " name=\"Calculator\",\n", " func=llm_math_chain.run,\n", " description=\"useful for when you need to answer questions about math\"\n", " ),\n", " Tool(\n", " name=\"FooBar DB\",\n", " func=db_chain.run,\n", " description=\"useful for when you need to answer questions about FooBar. 
Input should be in the form of a question\"\n", " )\n", "]" ] }, { "cell_type": "code", "execution_count": 10, "id": "a069c4b6", "metadata": {}, "outputs": [], "source": [ "mrkl = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)" ] }, { "cell_type": "code", "execution_count": 11, "id": "356a1396", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "What was our total number of sales and total revenue in 2013?\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n", "Action: FooBar DB\n", "Action Input: sales 2013\u001b[0m\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "sales 2013\n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE InvoiceDate BETWEEN '2013-01-01' AND '2013-12-31'\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find out how many sales and how much revenue we had in 2013\n", "Action: FooBar DB\n", "Action Input: revenue 2013\u001b[0m\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "revenue 2013\n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT SUM(Total) FROM Invoice WHERE strftime('%Y', InvoiceDate) = '2013'\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[(450.58000000000027,)]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m 450.58\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "Observation: \u001b[36;1m\u001b[1;3m 450.58\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: We had 450.58 sales and 450.58 revenue in 2013\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ "'We had 450.58 sales and 450.58 revenue in 2013'" 
] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mrkl.run(\"What was our total number of sales and total revenue in 2013?\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "4e9c9b23", "metadata": {}, "outputs": [], "source": [ "from langchain.agents import ZeroShotAgent\n", "from langchain import PromptTemplate\n", "from langchain import LLMChain\n", "TEMPLATE = \"\"\"You are a data engineer answering questions using a SQL database.\n", "\n", "Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. If a previous query produced an error, do NOT try it again.\n", "\n", "Only use the following tables:\n", "\n", "{table_info}\n", "\n", "Use the following format:\n", "\n", "Question: the input question you must answer\n", "Thought: you should always think about what to do\n", "Action: SQLDB\n", "Action Input: the query to run against the SQL database\n", "Observation: the result of the action\n", "... 
(this Thought/Action/Action Input/Observation can repeat N times)\n", "Thought: I now know the final answer\n", "Final Answer: the final answer to the original input question\n", "\n", "Begin!\n", "\n", "Question: {{input}}\"\"\".format(**{\n", " \"dialect\": db.dialect,\n", " \"table_info\": db.table_info})\n", "prompt = PromptTemplate(template=TEMPLATE, input_variables=[\"input\"])" ] }, { "cell_type": "code", "execution_count": 19, "id": "a6fb1b4f", "metadata": {}, "outputs": [], "source": [ "llm_chain = LLMChain(llm=llm, prompt=prompt)\n", "# The custom TEMPLATE hard-codes `Action: SQLDB`, so this tool name must match it.\n", "sql_tools = [\n", " Tool(\"SQLDB\", db.run, \"useful for running a SQL query against the FooBar database. Input should be a syntactically correct SQL query\")\n", "]\n", "agent = ZeroShotAgent(llm_chain=llm_chain, tools=sql_tools, verbose=True)" ] }, { "cell_type": "code", "execution_count": 20, "id": "64578802", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "What was our total number of sales and total revenue in 2013?\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find the total number of sales and total revenue\n", "Action: SQLDB\n", "Action Input: SELECT SUM(Total) AS \"Total Sales\", SUM(Total) AS \"Total Revenue\" FROM Invoice WHERE InvoiceDate LIKE '2013%'\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3m[(450.58000000000027, 450.58000000000027)]\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: 450.58\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ "'450.58'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent.run(\"What was our total number of sales and total revenue in 2013?\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "e603cd7d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Olivia Wilde's boyfriend\n", "Action: 
Search\n", "Action Input: \"Olivia Wilde's boyfriend\"\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find the age of Harry Styles\n", "Action: Search\n", "Action Input: \"Harry Styles age\"\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3m28 years\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 28 to the 0.23 power\n", "Action: Calculator\n", "Action Input: 28^0.23\u001b[0m\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "28^0.23\u001b[32;1m\u001b[1;3m\n", "\n", "```python\n", "print(28**0.23)\n", "```\n", "\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m2.1520202182226886\n", "\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "Observation: \u001b[33;1m\u001b[1;3mAnswer: 2.1520202182226886\n", "\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: 2.1520202182226886\u001b[0m" ] }, { "data": { "text/plain": [ "'2.1520202182226886'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mrkl.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "a5c07010", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? 
If so, what albums of theirs are in the FooBar database?\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find an album called 'The Storm Before the Calm'\n", "Action: Search\n", "Action Input: \"The Storm Before the Calm album\"\u001b[0m\n", "Observation: \u001b[36;1m\u001b[1;3mThe Storm Before the Calm (stylized in all lowercase) is the tenth (and eighth international) studio album by Canadian-American singer-songwriter Alanis ...\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I need to check if Alanis is in the FooBar database\n", "Action: FooBar DB\n", "Action Input: \"Does Alanis Morissette exist in the FooBar database?\"\u001b[0m\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "Does Alanis Morissette exist in the FooBar database?\n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT * FROM Artist WHERE Name = 'Alanis Morissette'\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[(4, 'Alanis Morissette')]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Yes\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "Observation: \u001b[38;5;200m\u001b[1;3m Yes\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find out what albums of Alanis's are in the FooBar database\n", "Action: FooBar DB\n", "Action Input: \"What albums by Alanis Morissette are in the FooBar database?\"\u001b[0m\n", "\n", "\u001b[1m> Entering new chain...\u001b[0m\n", "What albums by Alanis Morissette are in the FooBar database?\n", "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Album.Title FROM Album JOIN Artist ON Album.ArtistId = Artist.ArtistId WHERE Artist.Name = 'Alanis Morissette'\u001b[0m\n", "SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n", "Answer:\u001b[32;1m\u001b[1;3m Jagged Little Pill\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "Observation: \u001b[38;5;200m\u001b[1;3m Jagged Little Pill\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: The album is by Alanis Morissette and the albums in the FooBar 
database by her are Jagged Little Pill\u001b[0m" ] }, { "data": { "text/plain": [ "'The album is by Alanis Morissette and the albums in the FooBar database by her are Jagged Little Pill'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mrkl.run(\"Who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }