From f10c17c6a45f22afb610141b63b63a84e6aa7e50 Mon Sep 17 00:00:00 2001 From: Lance Martin <122662504+rlancemartin@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:34:37 -0700 Subject: [PATCH] Update SQL templates (#12464) --- templates/sql-llama2/README.md | 16 +++-- templates/sql-llama2/nba_roster.db | 0 templates/sql-llama2/sql_llama2.ipynb | 65 +++++++++++++++++++ templates/sql-llama2/sql_llama2/__init__.py | 2 +- templates/sql-llama2/sql_llama2/chain.py | 10 +-- templates/sql-llamacpp/README.md | 7 -- templates/sql-llamacpp/nba_roster.db | 0 templates/sql-llamacpp/sql-llamacpp.ipynb | 54 +++++++++++++++ .../sql-llamacpp/sql_llamacpp/__init__.py | 2 +- templates/sql-llamacpp/sql_llamacpp/chain.py | 6 +- templates/sql-ollama/README.md | 9 +-- templates/sql-ollama/nba_roster.db | 0 templates/sql-ollama/sql-ollama.ipynb | 54 +++++++++++++++ templates/sql-ollama/sql_ollama/chain.py | 7 +- 14 files changed, 204 insertions(+), 28 deletions(-) delete mode 100644 templates/sql-llama2/nba_roster.db create mode 100644 templates/sql-llama2/sql_llama2.ipynb delete mode 100644 templates/sql-llamacpp/nba_roster.db create mode 100644 templates/sql-llamacpp/sql-llamacpp.ipynb delete mode 100644 templates/sql-ollama/nba_roster.db create mode 100644 templates/sql-ollama/sql-ollama.ipynb diff --git a/templates/sql-llama2/README.md b/templates/sql-llama2/README.md index c3574f066a4..939bea60bb8 100644 --- a/templates/sql-llama2/README.md +++ b/templates/sql-llama2/README.md @@ -8,8 +8,14 @@ But, it can be adapted to any API that support LLaMA2, including [Fireworks](htt See related templates `sql-ollama` and `sql-llamacpp` for private, local chat with SQL. -## Installation -```bash -# from inside your LangServe instance -poe add sql-llama2 -``` +## Set up SQL DB + +This template includes an example DB of 2023 NBA rosters. + +You can see instructions to build this DB [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb). + +## LLM + +This template will use a `Replicate` [hosted version](https://replicate.com/meta/llama-2-13b-chat/versions/f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d) of LLaMA2. + +Be sure that `REPLICATE_API_TOKEN` is set in your environment. \ No newline at end of file diff --git a/templates/sql-llama2/nba_roster.db b/templates/sql-llama2/nba_roster.db deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/templates/sql-llama2/sql_llama2.ipynb b/templates/sql-llama2/sql_llama2.ipynb new file mode 100644 index 00000000000..a8991bf946e --- /dev/null +++ b/templates/sql-llama2/sql_llama2.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "22f3f9f9-80ee-4da1-ba12-105a0ce74203", + "metadata": {}, + "source": [ + "## Run Template\n", + "\n", + "In `server.py`, set -\n", + "```\n", + "add_routes(app, chain, path=\"/sql_llama2\")\n", + "```\n", + "\n", + "This template includes an example DB of 2023 NBA rosters.\n", + "\n", + "We can ask questions related to NBA players. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4545c603-77ec-4c15-b9c0-a70529eebed0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" Sure thing! Here's the natural language response based on the given SQL query and response:\\n\\nKlay Thompson plays for the Golden State Warriors.\"" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langserve.client import RemoteRunnable\n", + "sql_app = RemoteRunnable('http://0.0.0.0:8001/sql_llama2')\n", + "sql_app.invoke({\"question\": \"What team is Klay Thompson on?\"})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/templates/sql-llama2/sql_llama2/__init__.py b/templates/sql-llama2/sql_llama2/__init__.py index 7e2706adcab..efd92085aba 100644 --- a/templates/sql-llama2/sql_llama2/__init__.py +++ b/templates/sql-llama2/sql_llama2/__init__.py @@ -1,3 +1,3 @@ -from llama2.chain import chain +from sql_llama2.chain import chain __all__ = ["chain"] \ No newline at end of file diff --git a/templates/sql-llama2/sql_llama2/chain.py b/templates/sql-llama2/sql_llama2/chain.py index 7e53f5d147c..7651cc37017 100644 --- a/templates/sql-llama2/sql_llama2/chain.py +++ b/templates/sql-llama2/sql_llama2/chain.py @@ -2,6 +2,7 @@ from pathlib import Path from langchain.llms import Replicate from langchain.prompts import ChatPromptTemplate +from langchain.pydantic_v1 import BaseModel from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnablePassthrough from langchain.utilities import SQLDatabase @@ -14,13 +15,11 @@ llm = Replicate( model_kwargs={"temperature": 0.01, "max_length": 500, "top_p": 1}, ) - db_path = Path(__file__).parent / "nba_roster.db" rel = db_path.relative_to(Path.cwd()) db_string = f"sqlite:///{rel}" db = SQLDatabase.from_uri(db_string, sample_rows_in_table_info=0) - def get_schema(_): return db.get_table_info() @@ -28,7 +27,6 @@ def get_schema(_): def run_query(query): return db.run(query) - template_query = """Based on the table schema below, write a SQL query that would answer the user's question: {schema} @@ -66,8 +64,12 @@ prompt_response = ChatPromptTemplate.from_messages( ] ) +# Supply the input types to the prompt +class InputType(BaseModel): + question: str + chain = ( - RunnablePassthrough.assign(query=sql_response) + RunnablePassthrough.assign(query=sql_response).with_types(input_type=InputType) | RunnablePassthrough.assign( schema=get_schema, response=lambda x: db.run(x["query"]), diff --git a/templates/sql-llamacpp/README.md b/templates/sql-llamacpp/README.md index 91c41c7a96b..14dd0df00cd 100644 --- a/templates/sql-llamacpp/README.md +++ b/templates/sql-llamacpp/README.md @@ -27,10 +27,3 @@ You can select other files and specify their download path (browse [here](https: This template includes an example DB of 2023 NBA rosters. You can see instructions to build this DB [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb). - -## Installation - -```bash -# from inside your LangServe instance -poe add sql/llama2-ollama -``` diff --git a/templates/sql-llamacpp/nba_roster.db b/templates/sql-llamacpp/nba_roster.db deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/templates/sql-llamacpp/sql-llamacpp.ipynb b/templates/sql-llamacpp/sql-llamacpp.ipynb new file mode 100644 index 00000000000..b3bf2c3e48c --- /dev/null +++ b/templates/sql-llamacpp/sql-llamacpp.ipynb @@ -0,0 +1,54 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a0314df0-da99-4086-a96f-b14df05b3362", + "metadata": {}, + "source": [ + "## Run Template\n", + "\n", + "In `server.py`, set -\n", + "```\n", + "add_routes(app, chain, path=\"/sql_llamacpp\")\n", + "```\n", + "\n", + "This template includes an example DB of 2023 NBA rosters.\n", + "\n", + "We can ask questions related to NBA players. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff5869c6-2065-48f3-bb43-52a515968276", + "metadata": {}, + "outputs": [], + "source": [ + "from langserve.client import RemoteRunnable\n", + "sql_app = RemoteRunnable('http://0.0.0.0:8001/sql_llamacpp')\n", + "sql_app.invoke({\"question\": \"What team is Klay Thompson on?\"})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/templates/sql-llamacpp/sql_llamacpp/__init__.py b/templates/sql-llamacpp/sql_llamacpp/__init__.py index 376ed218308..7a48c328488 100644 --- a/templates/sql-llamacpp/sql_llamacpp/__init__.py +++ b/templates/sql-llamacpp/sql_llamacpp/__init__.py @@ -1,3 +1,3 @@ -from llamacpp.chain import chain +from sql_llamacpp.chain import chain __all__ = ["chain"] \ No newline at end of file diff --git a/templates/sql-llamacpp/sql_llamacpp/chain.py b/templates/sql-llamacpp/sql_llamacpp/chain.py index 9ec75ed23e5..e7ff025f4f0 100644 --- a/templates/sql-llamacpp/sql_llamacpp/chain.py +++ b/templates/sql-llamacpp/sql_llamacpp/chain.py @@ -6,6 +6,7 @@ import requests from langchain.llms import LlamaCpp from langchain.memory import ConversationBufferMemory from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain.pydantic_v1 import BaseModel from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnableLambda, RunnablePassthrough from langchain.utilities import SQLDatabase @@ -113,9 +114,12 @@ prompt_response = ChatPromptTemplate.from_messages( ("human", template), ] ) +# Supply the input types to the prompt +class InputType(BaseModel): + question: str chain = ( - RunnablePassthrough.assign(query=sql_response_memory) + RunnablePassthrough.assign(query=sql_response_memory).with_types(input_type=InputType) | RunnablePassthrough.assign( schema=get_schema, response=lambda x: db.run(x["query"]), diff --git a/templates/sql-ollama/README.md b/templates/sql-ollama/README.md index 67d5aeba759..c7b9769f3d9 100644 --- a/templates/sql-ollama/README.md +++ b/templates/sql-ollama/README.md @@ -15,11 +15,4 @@ Also follow instructions to download your LLM of interest: This template includes an example DB of 2023 NBA rosters. -You can see instructions to build this DB [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb). - -## Installation - -```bash -# from inside your LangServe instance -poe add sql-ollama -``` +You can see instructions to build this DB [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/StructuredLlama.ipynb). \ No newline at end of file diff --git a/templates/sql-ollama/nba_roster.db b/templates/sql-ollama/nba_roster.db deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/templates/sql-ollama/sql-ollama.ipynb b/templates/sql-ollama/sql-ollama.ipynb new file mode 100644 index 00000000000..e41c07e8fe0 --- /dev/null +++ b/templates/sql-ollama/sql-ollama.ipynb @@ -0,0 +1,54 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d55f5fd9-21eb-433d-9259-0a588d9197c0", + "metadata": {}, + "source": [ + "## Run Template\n", + "\n", + "In `server.py`, set -\n", + "```\n", + "add_routes(app, chain, path=\"/sql_ollama\")\n", + "```\n", + "\n", + "This template includes an example DB of 2023 NBA rosters.\n", + "\n", + "We can ask questions related to NBA players. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50c27e82-92d8-4fa1-8bc4-b6544e59773d", + "metadata": {}, + "outputs": [], + "source": [ + "from langserve.client import RemoteRunnable\n", + "sql_app = RemoteRunnable('http://0.0.0.0:8001/sql_ollama')\n", + "sql_app.invoke({\"question\": \"What team is Klay Thompson on?\"})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/templates/sql-ollama/sql_ollama/chain.py b/templates/sql-ollama/sql_ollama/chain.py index 7cf55746b88..45e5b2b3e4a 100644 --- a/templates/sql-ollama/sql_ollama/chain.py +++ b/templates/sql-ollama/sql_ollama/chain.py @@ -3,6 +3,7 @@ from pathlib import Path from langchain.chat_models import ChatOllama from langchain.memory import ConversationBufferMemory from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain.pydantic_v1 import BaseModel from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnableLambda, RunnablePassthrough from langchain.utilities import SQLDatabase @@ -82,8 +83,12 @@ prompt_response = ChatPromptTemplate.from_messages( ] ) +# Supply the input types to the prompt +class InputType(BaseModel): + question: str + chain = ( - RunnablePassthrough.assign(query=sql_response_memory) + RunnablePassthrough.assign(query=sql_response_memory).with_types(input_type=InputType) | RunnablePassthrough.assign( schema=get_schema, response=lambda x: db.run(x["query"]),