mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
add sql database (#35)
This commit is contained in:
parent
90a6e578bc
commit
af81e9ca9c
93
examples/sqlite.ipynb
Normal file
93
examples/sqlite.ipynb
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b2f66479",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"This uses the example Chinook database.\n",
|
||||||
|
"To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "d0e27d88",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain import OpenAI, SQLDatabase, SQLDatabaseChain"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "72ede462",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n",
|
||||||
|
"llm = OpenAI(temperature=0)\n",
|
||||||
|
"db_chain = SQLDatabaseChain(llm=llm, database=db)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "15ff81df",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" SELECT COUNT(*) FROM Employee\n",
|
||||||
|
"[(8,)]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"' There are 8 employees.'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"db_chain.query(\"How many employees are there?\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "146fa162",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -12,10 +12,12 @@ from langchain.chains import (
|
|||||||
ReActChain,
|
ReActChain,
|
||||||
SelfAskWithSearchChain,
|
SelfAskWithSearchChain,
|
||||||
SerpAPIChain,
|
SerpAPIChain,
|
||||||
|
SQLDatabaseChain,
|
||||||
)
|
)
|
||||||
from langchain.docstore import Wikipedia
|
from langchain.docstore import Wikipedia
|
||||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||||
from langchain.prompt import Prompt
|
from langchain.prompt import Prompt
|
||||||
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLMChain",
|
"LLMChain",
|
||||||
@ -29,4 +31,6 @@ __all__ = [
|
|||||||
"ReActChain",
|
"ReActChain",
|
||||||
"Wikipedia",
|
"Wikipedia",
|
||||||
"HuggingFaceHub",
|
"HuggingFaceHub",
|
||||||
|
"SQLDatabase",
|
||||||
|
"SQLDatabaseChain",
|
||||||
]
|
]
|
||||||
|
@ -5,6 +5,7 @@ from langchain.chains.python import PythonChain
|
|||||||
from langchain.chains.react.base import ReActChain
|
from langchain.chains.react.base import ReActChain
|
||||||
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||||
from langchain.chains.serpapi import SerpAPIChain
|
from langchain.chains.serpapi import SerpAPIChain
|
||||||
|
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLMChain",
|
"LLMChain",
|
||||||
@ -13,4 +14,5 @@ __all__ = [
|
|||||||
"SelfAskWithSearchChain",
|
"SelfAskWithSearchChain",
|
||||||
"SerpAPIChain",
|
"SerpAPIChain",
|
||||||
"ReActChain",
|
"ReActChain",
|
||||||
|
"SQLDatabaseChain",
|
||||||
]
|
]
|
||||||
|
1
langchain/chains/sql_database/__init__.py
Normal file
1
langchain/chains/sql_database/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Chain for interacting with SQL Database."""
|
85
langchain/chains/sql_database/base.py
Normal file
85
langchain/chains/sql_database/base.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
"""Chain for interacting with SQL Database."""
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.sql_database.prompt import PROMPT
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class SQLDatabaseChain(Chain, BaseModel):
    """Chain for interacting with SQL Database.

    Given a natural language question, the chain asks the LLM for a SQL
    query, runs that query against ``database``, and then asks the LLM to
    phrase a final answer from the query result.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
    """

    llm: LLM
    """LLM wrapper to use."""
    database: SQLDatabase
    """SQL Database to connect to."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Generate a SQL query, execute it, and produce the final answer.

        Args:
            inputs: Mapping that holds the question under ``self.input_key``.

        Returns:
            Mapping with the final answer stored under ``self.output_key``.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        _input = inputs[self.input_key] + "\nSQLQuery:"
        llm_inputs = {
            "input": _input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_info,
            # NOTE(review): presumably makes the LLM stop before it writes its
            # own "SQLResult:" line -- confirm against LLMChain.predict.
            "stop": ["\nSQLResult:"],
        }
        sql_cmd = llm_chain.predict(**llm_inputs)
        print(sql_cmd)
        result = self.database.run(sql_cmd)
        print(result)
        # Keep the generated SQL in the transcript so the second prompt follows
        # the Question / SQLQuery / SQLResult / Answer format of the template;
        # previously the query text was dropped and "SQLQuery:" was left empty.
        _input += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
        llm_inputs["input"] = _input
        final_result = llm_chain.predict(**llm_inputs)
        return {self.output_key: final_result}

    def query(self, query: str) -> str:
        """Run natural language query against a SQL database.

        Args:
            query: natural language query to run against the SQL database

        Returns:
            The final answer as derived from the SQL database.

        Example:
            .. code-block:: python

                answer = db_chain.query("How many customers are there?")
        """
        return self({self.input_key: query})[self.output_key]
|
20
langchain/chains/sql_database/prompt.py
Normal file
20
langchain/chains/sql_database/prompt.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from langchain.prompt import Prompt
|
||||||
|
|
||||||
|
# Prompt text for the SQL database chain: asks for a {dialect} query, then an
# answer derived from that query's result. Placeholders filled at run time are
# {dialect}, {table_info} and {input}.
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

Question: {input}"""
# Prompt object wrapping the template; input_variables names the three
# placeholders that appear in the template above.
PROMPT = Prompt(
    input_variables=["input", "table_info", "dialect"],
    template=_DEFAULT_TEMPLATE,
)
|
41
langchain/sql_database.py
Normal file
41
langchain/sql_database.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""SQLAlchemy wrapper around a database."""
|
||||||
|
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
|
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Exposes the dialect name, a textual description of every table, and a
    way to run a raw SQL command and get the fetched rows back as a string.
    """

    def __init__(self, engine: Engine):
        """Wrap an already-constructed SQLAlchemy engine.

        Args:
            engine: Engine used for all inspection and command execution.
        """
        self._engine = engine

    @classmethod
    def from_uri(cls, database_uri: str) -> "SQLDatabase":
        """Construct a SQLDatabase from a database URI.

        Args:
            database_uri: URI passed to ``sqlalchemy.create_engine``.

        Returns:
            A new instance wrapping the created engine.
        """
        return cls(create_engine(database_uri))

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    @property
    def table_info(self) -> str:
        """Information about all tables in the database.

        Returns:
            One line per table reported by the inspector, each of the form
            ``The '<table>' table has columns: <name> (<type>), ...``, joined
            with newlines.
        """
        template = "The '{table_name}' table has columns: {columns}."
        inspector = inspect(self._engine)
        tables = []
        for table_name in inspector.get_table_names():
            column_str = ", ".join(
                f"{column['name']} ({column['type']!s})"
                for column in inspector.get_columns(table_name)
            )
            tables.append(template.format(table_name=table_name, columns=column_str))
        return "\n".join(tables)

    def run(self, command: str) -> str:
        """Execute a SQL command and return a string of the results.

        Args:
            command: Raw SQL text to execute.

        Returns:
            ``str()`` of the list of fetched rows, e.g. ``"[('Harrison',)]"``.
        """
        # ``Engine.execute`` was deprecated in SQLAlchemy 1.4 and removed in
        # 2.0, so run the statement on an explicit connection instead.
        # ``begin()`` commits on success, which keeps the autocommit behaviour
        # ``Engine.execute`` gave data-modifying statements.
        with self._engine.begin() as connection:
            rows = connection.execute(text(command)).fetchall()
        return str(rows)
|
2
setup.py
2
setup.py
@ -14,7 +14,7 @@ setup(
|
|||||||
version=__version__,
|
version=__version__,
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
description="Building applications with LLMs through composability",
|
description="Building applications with LLMs through composability",
|
||||||
install_requires=["pydantic"],
|
install_requires=["pydantic", "sqlalchemy"],
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
url="https://github.com/hwchase17/langchain",
|
url="https://github.com/hwchase17/langchain",
|
||||||
|
30
tests/integration_tests/chains/test_sql_database.py
Normal file
30
tests/integration_tests/chains/test_sql_database.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""Test SQL Database Chain."""
|
||||||
|
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||||
|
|
||||||
|
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
|
# Schema registry for the table used by the test in this module.
metadata_obj = MetaData()

# "user" table with an integer primary key plus required name and company
# columns (both VARCHAR(16)).
user = Table(
    "user",
    metadata_obj,
    Column("user_id", Integer, primary_key=True),
    Column("user_name", String(16), nullable=False),
    Column("user_company", String(16), nullable=False),
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_database_run() -> None:
    """Test that the SQL database chain answers a question from table data."""
    engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(engine)
    stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo")
    # begin() commits on exit. A bare connect() block is not committed under
    # SQLAlchemy 2.0, which would leave the table empty when the chain queries it.
    with engine.begin() as conn:
        conn.execute(stmt)
    db = SQLDatabase(engine)
    db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
    output = db_chain.query("What company does Harrison work at?")
    expected_output = " Harrison works at Foo."
    assert output == expected_output
|
49
tests/unit_tests/test_sql_database.py
Normal file
49
tests/unit_tests/test_sql_database.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""Test SQL database wrapper."""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||||
|
|
||||||
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
|
# Schema registry shared by the tests in this module.
metadata_obj = MetaData()

# "user" table: integer primary key and a required VARCHAR(16) name.
user = Table(
    "user",
    metadata_obj,
    Column("user_id", Integer, primary_key=True),
    Column("user_name", String(16), nullable=False),
)

# "company" table: integer primary key and a required, unsized string location.
company = Table(
    "company",
    metadata_obj,
    Column("company_id", Integer, primary_key=True),
    Column("company_location", String, nullable=False),
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_table_info() -> None:
    """Check the per-table description text returned by ``table_info``."""
    in_memory_engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(in_memory_engine)
    # One expected line per table, in the order table_info reports them.
    expected_lines = [
        "The 'company' table has columns: company_id (INTEGER), "
        "company_location (VARCHAR).",
        "The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16)).",
    ]
    assert SQLDatabase(in_memory_engine).table_info == "\n".join(expected_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sql_database_run() -> None:
    """Test that commands can be run successfully and returned in correct format."""
    engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(engine)
    stmt = insert(user).values(user_id=13, user_name="Harrison")
    # begin() commits on exit. A bare connect() block is not committed under
    # SQLAlchemy 2.0, so the inserted row would be gone before the select runs.
    with engine.begin() as conn:
        conn.execute(stmt)
    db = SQLDatabase(engine)
    command = "select user_name from user where user_id = 13"
    output = db.run(command)
    expected_output = "[('Harrison',)]"
    assert output == expected_output
|
Loading…
Reference in New Issue
Block a user