From af81e9ca9c54c5039368ca11f5f63007f2ad04b0 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 27 Oct 2022 23:21:47 -0700 Subject: [PATCH] add sql database (#35) --- examples/sqlite.ipynb | 93 +++++++++++++++++++ langchain/__init__.py | 4 + langchain/chains/__init__.py | 2 + langchain/chains/sql_database/__init__.py | 1 + langchain/chains/sql_database/base.py | 85 +++++++++++++++++ langchain/chains/sql_database/prompt.py | 20 ++++ langchain/sql_database.py | 41 ++++++++ setup.py | 2 +- .../chains/test_sql_database.py | 30 ++++++ tests/unit_tests/test_sql_database.py | 49 ++++++++++ 10 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 examples/sqlite.ipynb create mode 100644 langchain/chains/sql_database/__init__.py create mode 100644 langchain/chains/sql_database/base.py create mode 100644 langchain/chains/sql_database/prompt.py create mode 100644 langchain/sql_database.py create mode 100644 tests/integration_tests/chains/test_sql_database.py create mode 100644 tests/unit_tests/test_sql_database.py diff --git a/examples/sqlite.ipynb b/examples/sqlite.ipynb new file mode 100644 index 00000000000..39eb016deb6 --- /dev/null +++ b/examples/sqlite.ipynb @@ -0,0 +1,93 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b2f66479", + "metadata": {}, + "source": [ + "This uses the example Chinook database.\n", + "To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d0e27d88", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import OpenAI, SQLDatabase, SQLDatabaseChain" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "72ede462", + "metadata": {}, + "outputs": [], + "source": [ + "db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n", + "llm = OpenAI(temperature=0)\n", + "db_chain = SQLDatabaseChain(llm=llm, database=db)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "15ff81df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " SELECT COUNT(*) FROM Employee\n", + "[(8,)]\n" + ] + }, + { + "data": { + "text/plain": [ + "' There are 8 employees.'" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "db_chain.query(\"How many employees are there?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "146fa162", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/__init__.py b/langchain/__init__.py index c0b4f18c870..ef77f3ffad2 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -12,10 +12,12 @@ from langchain.chains import ( ReActChain, SelfAskWithSearchChain, SerpAPIChain, + SQLDatabaseChain, ) from langchain.docstore import Wikipedia from langchain.llms import Cohere, HuggingFaceHub, OpenAI from langchain.prompt import Prompt +from langchain.sql_database import SQLDatabase __all__ = [ "LLMChain", @@ 
-29,4 +31,6 @@ __all__ = [ "ReActChain", "Wikipedia", "HuggingFaceHub", + "SQLDatabase", + "SQLDatabaseChain", ] diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 2de0cb6ee7b..adde1ed8cc1 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -5,6 +5,7 @@ from langchain.chains.python import PythonChain from langchain.chains.react.base import ReActChain from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain from langchain.chains.serpapi import SerpAPIChain +from langchain.chains.sql_database.base import SQLDatabaseChain __all__ = [ "LLMChain", @@ -13,4 +14,5 @@ __all__ = [ "SelfAskWithSearchChain", "SerpAPIChain", "ReActChain", + "SQLDatabaseChain", ] diff --git a/langchain/chains/sql_database/__init__.py b/langchain/chains/sql_database/__init__.py new file mode 100644 index 00000000000..b704f72c280 --- /dev/null +++ b/langchain/chains/sql_database/__init__.py @@ -0,0 +1 @@ +"""Chain for interacting with SQL Database.""" diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py new file mode 100644 index 00000000000..c6392273d8a --- /dev/null +++ b/langchain/chains/sql_database/base.py @@ -0,0 +1,85 @@ +"""Chain for interacting with SQL Database.""" +from typing import Dict, List + +from pydantic import BaseModel, Extra + +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.sql_database.prompt import PROMPT +from langchain.llms.base import LLM +from langchain.sql_database import SQLDatabase + + +class SQLDatabaseChain(Chain, BaseModel): + """Chain for interacting with SQL Database. + + Example: + .. code-block:: python + + from langchain import SQLDatabaseChain, OpenAI, SQLDatabase + db = SQLDatabase(...) 
+ db_chain = SQLDatabaseChain(llm=OpenAI(), database=db) + """ + + llm: LLM + """LLM wrapper to use.""" + database: SQLDatabase + """SQL Database to connect to.""" + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def input_keys(self) -> List[str]: + """Return the singular input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the singular output key. + + :meta private: + """ + return [self.output_key] + + def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: + llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) + _input = inputs[self.input_key] + "\nSQLQuery:" + llm_inputs = { + "input": _input, + "dialect": self.database.dialect, + "table_info": self.database.table_info, + "stop": ["\nSQLResult:"], + } + sql_cmd = llm_chain.predict(**llm_inputs) + print(sql_cmd) + result = self.database.run(sql_cmd) + print(result) + _input += f"\nSQLResult: {result}\nAnswer:" + llm_inputs["input"] = _input + final_result = llm_chain.predict(**llm_inputs) + return {self.output_key: final_result} + + def query(self, query: str) -> str: + """Run natural language query against a SQL database. + + Args: + query: natural language query to run against the SQL database + + Returns: + The final answer as derived from the SQL database. + + Example: + .. 
code-block:: python + + answer = db_chain.query("How many customers are there?") + """ + return self({self.input_key: query})[self.output_key] diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py new file mode 100644 index 00000000000..36d48d74e44 --- /dev/null +++ b/langchain/chains/sql_database/prompt.py @@ -0,0 +1,20 @@ +# flake8: noqa +from langchain.prompt import Prompt + +_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. +Use the following format: + +Question: "Question here" +SQLQuery: "SQL Query to run" +SQLResult: "Result of the SQLQuery" +Answer: "Final answer here" + +Only use the following tables: + +{table_info} + +Question: {input}""" +PROMPT = Prompt( + input_variables=["input", "table_info", "dialect"], + template=_DEFAULT_TEMPLATE, +) diff --git a/langchain/sql_database.py b/langchain/sql_database.py new file mode 100644 index 00000000000..138839bb3dc --- /dev/null +++ b/langchain/sql_database.py @@ -0,0 +1,41 @@ +"""SQLAlchemy wrapper around a database.""" +from sqlalchemy import create_engine, inspect +from sqlalchemy.engine import Engine + + +class SQLDatabase: + """SQLAlchemy wrapper around a database.""" + + def __init__(self, engine: Engine): + """Wrap an existing SQLAlchemy engine.""" + self._engine = engine + + @classmethod + def from_uri(cls, database_uri: str) -> "SQLDatabase": + """Construct a SQLAlchemy engine from URI.""" + return cls(create_engine(database_uri)) + + @property + def dialect(self) -> str: + """Return string representation of dialect to use.""" + return self._engine.dialect.name + + @property + def table_info(self) -> str: + """Information about all tables in the database.""" + template = "The '{table_name}' table has columns: {columns}." 
+ tables = [] + inspector = inspect(self._engine) + for table_name in inspector.get_table_names(): + columns = [] + for column in inspector.get_columns(table_name): + columns.append(f"{column['name']} ({str(column['type'])})") + column_str = ", ".join(columns) + table_str = template.format(table_name=table_name, columns=column_str) + tables.append(table_str) + return "\n".join(tables) + + def run(self, command: str) -> str: + """Execute a SQL command and return a string of the results.""" + result = self._engine.execute(command).fetchall() + return str(result) diff --git a/setup.py b/setup.py index b84bd60913d..4ac3f475dd9 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( version=__version__, packages=find_packages(), description="Building applications with LLMs through composability", - install_requires=["pydantic"], + install_requires=["pydantic", "sqlalchemy"], long_description=long_description, license="MIT", url="https://github.com/hwchase17/langchain", diff --git a/tests/integration_tests/chains/test_sql_database.py b/tests/integration_tests/chains/test_sql_database.py new file mode 100644 index 00000000000..f874cc25893 --- /dev/null +++ b/tests/integration_tests/chains/test_sql_database.py @@ -0,0 +1,30 @@ +"""Test SQL Database Chain.""" +from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert + +from langchain.chains.sql_database.base import SQLDatabaseChain +from langchain.llms.openai import OpenAI +from langchain.sql_database import SQLDatabase + +metadata_obj = MetaData() + +user = Table( + "user", + metadata_obj, + Column("user_id", Integer, primary_key=True), + Column("user_name", String(16), nullable=False), + Column("user_company", String(16), nullable=False), +) + + +def test_sql_database_run() -> None: + """Test that commands can be run successfully and returned in correct format.""" + engine = create_engine("sqlite:///:memory:") + metadata_obj.create_all(engine) + stmt = insert(user).values(user_id=13, 
user_name="Harrison", user_company="Foo") + with engine.connect() as conn: + conn.execute(stmt) + db = SQLDatabase(engine) + db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db) + output = db_chain.query("What company does Harrison work at?") + expected_output = " Harrison works at Foo." + assert output == expected_output diff --git a/tests/unit_tests/test_sql_database.py b/tests/unit_tests/test_sql_database.py new file mode 100644 index 00000000000..c18d5deb63a --- /dev/null +++ b/tests/unit_tests/test_sql_database.py @@ -0,0 +1,49 @@ +"""Test SQL database wrapper.""" + +from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert + +from langchain.sql_database import SQLDatabase + +metadata_obj = MetaData() + +user = Table( + "user", + metadata_obj, + Column("user_id", Integer, primary_key=True), + Column("user_name", String(16), nullable=False), +) + +company = Table( + "company", + metadata_obj, + Column("company_id", Integer, primary_key=True), + Column("company_location", String, nullable=False), +) + + +def test_table_info() -> None: + """Test that table info is constructed properly.""" + engine = create_engine("sqlite:///:memory:") + metadata_obj.create_all(engine) + db = SQLDatabase(engine) + output = db.table_info + expected_output = ( + "The 'company' table has columns: company_id (INTEGER), " + "company_location (VARCHAR).\n" + "The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16))." 
+ ) + assert output == expected_output + + +def test_sql_database_run() -> None: + """Test that commands can be run successfully and returned in correct format.""" + engine = create_engine("sqlite:///:memory:") + metadata_obj.create_all(engine) + stmt = insert(user).values(user_id=13, user_name="Harrison") + with engine.connect() as conn: + conn.execute(stmt) + db = SQLDatabase(engine) + command = "select user_name from user where user_id = 13" + output = db.run(command) + expected_output = "[('Harrison',)]" + assert output == expected_output