add sql database (#35)

This commit is contained in:
Harrison Chase 2022-10-27 23:21:47 -07:00 committed by GitHub
parent 90a6e578bc
commit af81e9ca9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 326 additions and 1 deletions

93
examples/sqlite.ipynb Normal file
View File

@ -0,0 +1,93 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b2f66479",
"metadata": {},
"source": [
"This uses the example Chinook database.\n",
"To set it up, follow the instructions at https://database.guide/2-sample-databases-sqlite/ and place the `.db` file in a `notebooks` folder at the root of this repository."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d0e27d88",
"metadata": {},
"outputs": [],
"source": [
"from langchain import OpenAI, SQLDatabase, SQLDatabaseChain"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "72ede462",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n",
"llm = OpenAI(temperature=0)\n",
"db_chain = SQLDatabaseChain(llm=llm, database=db)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "15ff81df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" SELECT COUNT(*) FROM Employee\n",
"[(8,)]\n"
]
},
{
"data": {
"text/plain": [
"' There are 8 employees.'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.query(\"How many employees are there?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "146fa162",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -12,10 +12,12 @@ from langchain.chains import (
ReActChain,
SelfAskWithSearchChain,
SerpAPIChain,
SQLDatabaseChain,
)
from langchain.docstore import Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompt import Prompt
from langchain.sql_database import SQLDatabase
__all__ = [
"LLMChain",
@ -29,4 +31,6 @@ __all__ = [
"ReActChain",
"Wikipedia",
"HuggingFaceHub",
"SQLDatabase",
"SQLDatabaseChain",
]

View File

@ -5,6 +5,7 @@ from langchain.chains.python import PythonChain
from langchain.chains.react.base import ReActChain
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.chains.sql_database.base import SQLDatabaseChain
__all__ = [
"LLMChain",
@ -13,4 +14,5 @@ __all__ = [
"SelfAskWithSearchChain",
"SerpAPIChain",
"ReActChain",
"SQLDatabaseChain",
]

View File

@ -0,0 +1 @@
"""Chain for interacting with SQL Database."""

View File

@ -0,0 +1,85 @@
"""Chain for interacting with SQL Database."""
from typing import Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT
from langchain.llms.base import LLM
from langchain.sql_database import SQLDatabase
class SQLDatabaseChain(Chain, BaseModel):
    """Chain for interacting with SQL Database.

    The chain asks the LLM to write a SQL query for a natural language
    question, runs that query against the wrapped database, and then asks
    the LLM to phrase a final answer from the query result.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
    """

    llm: LLM
    """LLM wrapper to use."""
    database: SQLDatabase
    """SQL Database to connect to."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Generate a SQL query, execute it, and produce the final answer.

        Args:
            inputs: mapping that holds the natural language question under
                ``self.input_key``.

        Returns:
            A mapping with the LLM's final answer under ``self.output_key``.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        # First pass: the prompt ends with "SQLQuery:" so the completion is
        # the SQL statement itself.
        _input = inputs[self.input_key] + "\nSQLQuery:"
        llm_inputs = {
            "input": _input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_info,
            # Stop sequence handed to the LLM; presumably this ends the first
            # completion before the model writes its own "SQLResult:" line.
            "stop": ["\nSQLResult:"],
        }
        sql_cmd = llm_chain.predict(**llm_inputs)
        print(sql_cmd)
        result = self.database.run(sql_cmd)
        print(result)
        # Second pass: append the real query result and ask for the answer.
        # NOTE(review): ``_input`` does not include ``sql_cmd``, so the second
        # prompt reads "SQLQuery:" followed directly by "SQLResult:" - confirm
        # this is intended.
        _input += f"\nSQLResult: {result}\nAnswer:"
        llm_inputs["input"] = _input
        final_result = llm_chain.predict(**llm_inputs)
        return {self.output_key: final_result}

    def query(self, query: str) -> str:
        """Run natural language query against a SQL database.

        Args:
            query: natural language query to run against the SQL database

        Returns:
            The final answer as derived from the SQL database.

        Example:
            .. code-block:: python

                answer = db_chain.query("How many customers are there?")
        """
        return self({self.input_key: query})[self.output_key]

View File

@ -0,0 +1,20 @@
# flake8: noqa
from langchain.prompt import Prompt
# Template for the SQL question-answering prompt. It has three placeholders:
# {dialect}, {table_info} and {input}. The Question/SQLQuery/SQLResult/Answer
# layout described here matches the "\nSQLQuery:", "\nSQLResult:" and
# "\nAnswer:" markers that SQLDatabaseChain._run appends to the input.
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
Question: {input}"""
# Prompt object built from the template above; input_variables lists the
# same three placeholder names the template uses.
PROMPT = Prompt(
input_variables=["input", "table_info", "dialect"],
template=_DEFAULT_TEMPLATE,
)

41
langchain/sql_database.py Normal file
View File

@ -0,0 +1,41 @@
"""SQLAlchemy wrapper around a database."""
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import Engine
class SQLDatabase:
    """SQLAlchemy wrapper around a database."""

    def __init__(self, engine: Engine):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: engine used for schema inspection and for running SQL.
        """
        self._engine = engine

    @classmethod
    def from_uri(cls, database_uri: str) -> "SQLDatabase":
        """Construct a wrapper around a new engine created from a URI.

        Args:
            database_uri: SQLAlchemy database URL passed to ``create_engine``.

        Returns:
            A ``SQLDatabase`` wrapping the newly created engine.
        """
        return cls(create_engine(database_uri))

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    @property
    def table_info(self) -> str:
        """Information about all tables in the database.

        Returns:
            One line per table, joined by newlines, each listing the table's
            columns as ``name (TYPE)``.
        """
        template = "The '{table_name}' table has columns: {columns}."
        inspector = inspect(self._engine)
        tables = []
        for table_name in inspector.get_table_names():
            column_str = ", ".join(
                f"{column['name']} ({column['type']!s})"
                for column in inspector.get_columns(table_name)
            )
            tables.append(template.format(table_name=table_name, columns=column_str))
        return "\n".join(tables)

    def run(self, command: str) -> str:
        """Execute a SQL command and return a string of the results.

        Args:
            command: raw SQL text, passed to the database driver unchanged.

        Returns:
            ``str()`` of the list of fetched rows, e.g. ``"[('Harrison',)]"``.
        """
        # NOTE(security): ``command`` is executed verbatim. Callers in this
        # codebase pass LLM-generated SQL, so connect with a database user
        # whose privileges are limited to what the model may do.
        #
        # ``Engine.execute`` (connectionless execution) is deprecated in
        # SQLAlchemy 1.4 and removed in 2.0. An explicit transaction via
        # ``begin()`` works on both and commits on success, matching the
        # autocommit behavior the old call provided.
        with self._engine.begin() as connection:
            rows = connection.exec_driver_sql(command).fetchall()
        return str(rows)

View File

@ -14,7 +14,7 @@ setup(
version=__version__,
packages=find_packages(),
description="Building applications with LLMs through composability",
install_requires=["pydantic"],
install_requires=["pydantic", "sqlalchemy"],
long_description=long_description,
license="MIT",
url="https://github.com/hwchase17/langchain",

View File

@ -0,0 +1,30 @@
"""Test SQL Database Chain."""
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.llms.openai import OpenAI
from langchain.sql_database import SQLDatabase
# Metadata registry for this test module; the table below is attached to it
# and created on the test engine via ``metadata_obj.create_all``.
metadata_obj = MetaData()
# "user" table with an integer primary key and two required string columns.
user = Table(
"user",
metadata_obj,
Column("user_id", Integer, primary_key=True),
Column("user_name", String(16), nullable=False),
Column("user_company", String(16), nullable=False),
)
def test_sql_database_run() -> None:
    """Test that the chain answers a question from data stored in the database.

    NOTE(review): this builds a real ``OpenAI`` LLM and asserts its exact
    wording, so it presumably needs API access and may be sensitive to model
    changes - confirm it is only run as an integration test.
    """
    engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(engine)
    stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo")
    # ``begin()`` commits when the block exits. SQLAlchemy 2.0 no longer
    # autocommits on a plain ``connect()``, so the row would otherwise be
    # rolled back before the chain queries it.
    with engine.begin() as conn:
        conn.execute(stmt)
    db = SQLDatabase(engine)
    db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
    output = db_chain.query("What company does Harrison work at?")
    expected_output = " Harrison works at Foo."
    assert output == expected_output

View File

@ -0,0 +1,49 @@
"""Test SQL database wrapper."""
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
from langchain.sql_database import SQLDatabase
# Metadata registry for this test module; both tables below are attached to
# it and created on each test engine via ``metadata_obj.create_all``.
metadata_obj = MetaData()
# "user" table: integer primary key plus a required 16-character name.
user = Table(
"user",
metadata_obj,
Column("user_id", Integer, primary_key=True),
Column("user_name", String(16), nullable=False),
)
# "company" table: integer primary key plus a required, unbounded string.
company = Table(
"company",
metadata_obj,
Column("company_id", Integer, primary_key=True),
Column("company_location", String, nullable=False),
)
def test_table_info() -> None:
    """Check the schema summary produced for the ``company`` and ``user`` tables."""
    in_memory_engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(in_memory_engine)
    summary = SQLDatabase(in_memory_engine).table_info
    # One sentence per table, with columns rendered as "name (TYPE)".
    expected_summary = (
        "The 'company' table has columns: company_id (INTEGER), "
        "company_location (VARCHAR).\n"
        "The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16))."
    )
    assert summary == expected_summary
def test_sql_database_run() -> None:
    """Test that commands can be run successfully and returned in correct format."""
    engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(engine)
    stmt = insert(user).values(user_id=13, user_name="Harrison")
    # ``begin()`` commits when the block exits. SQLAlchemy 2.0 no longer
    # autocommits on a plain ``connect()``, so the row would otherwise be
    # rolled back before ``db.run`` reads it.
    with engine.begin() as conn:
        conn.execute(stmt)
    db = SQLDatabase(engine)
    command = "select user_name from user where user_id = 13"
    output = db.run(command)
    expected_output = "[('Harrison',)]"
    assert output == expected_output