mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
add sql database (#35)
This commit is contained in:
parent
90a6e578bc
commit
af81e9ca9c
93
examples/sqlite.ipynb
Normal file
93
examples/sqlite.ipynb
Normal file
@ -0,0 +1,93 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b2f66479",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This uses the example Chinook database.\n",
|
||||
"To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d0e27d88",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import OpenAI, SQLDatabase, SQLDatabaseChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "72ede462",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///../notebooks/Chinook.db\")\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"db_chain = SQLDatabaseChain(llm=llm, database=db)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "15ff81df",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" SELECT COUNT(*) FROM Employee\n",
|
||||
"[(8,)]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' There are 8 employees.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db_chain.query(\"How many employees are there?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "146fa162",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -12,10 +12,12 @@ from langchain.chains import (
|
||||
ReActChain,
|
||||
SelfAskWithSearchChain,
|
||||
SerpAPIChain,
|
||||
SQLDatabaseChain,
|
||||
)
|
||||
from langchain.docstore import Wikipedia
|
||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||
from langchain.prompt import Prompt
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
@ -29,4 +31,6 @@ __all__ = [
|
||||
"ReActChain",
|
||||
"Wikipedia",
|
||||
"HuggingFaceHub",
|
||||
"SQLDatabase",
|
||||
"SQLDatabaseChain",
|
||||
]
|
||||
|
@ -5,6 +5,7 @@ from langchain.chains.python import PythonChain
|
||||
from langchain.chains.react.base import ReActChain
|
||||
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
|
||||
from langchain.chains.serpapi import SerpAPIChain
|
||||
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
@ -13,4 +14,5 @@ __all__ = [
|
||||
"SelfAskWithSearchChain",
|
||||
"SerpAPIChain",
|
||||
"ReActChain",
|
||||
"SQLDatabaseChain",
|
||||
]
|
||||
|
1
langchain/chains/sql_database/__init__.py
Normal file
1
langchain/chains/sql_database/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Chain for interacting with SQL Database."""
|
85
langchain/chains/sql_database/base.py
Normal file
85
langchain/chains/sql_database/base.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""Chain for interacting with SQL Database."""
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import PROMPT
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
|
||||
class SQLDatabaseChain(Chain, BaseModel):
|
||||
"""Chain for interacting with SQL Database.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
|
||||
db = SQLDatabase(...)
|
||||
db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
"""LLM wrapper to use."""
|
||||
database: SQLDatabase
|
||||
"""SQL Database to connect to."""
|
||||
input_key: str = "query" #: :meta private:
|
||||
output_key: str = "result" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the singular input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return the singular output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
|
||||
_input = inputs[self.input_key] + "\nSQLQuery:"
|
||||
llm_inputs = {
|
||||
"input": _input,
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": self.database.table_info,
|
||||
"stop": ["\nSQLResult:"],
|
||||
}
|
||||
sql_cmd = llm_chain.predict(**llm_inputs)
|
||||
print(sql_cmd)
|
||||
result = self.database.run(sql_cmd)
|
||||
print(result)
|
||||
_input += f"\nSQLResult: {result}\nAnswer:"
|
||||
llm_inputs["input"] = _input
|
||||
final_result = llm_chain.predict(**llm_inputs)
|
||||
return {self.output_key: final_result}
|
||||
|
||||
def query(self, query: str) -> str:
|
||||
"""Run natural language query against a SQL database.
|
||||
|
||||
Args:
|
||||
query: natural language query to run against the SQL database
|
||||
|
||||
Returns:
|
||||
The final answer as derived from the SQL database.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
answer = db_chain.query("How many customers are there?")
|
||||
"""
|
||||
return self({self.input_key: query})[self.output_key]
|
20
langchain/chains/sql_database/prompt.py
Normal file
20
langchain/chains/sql_database/prompt.py
Normal file
@ -0,0 +1,20 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompt import Prompt
|
||||
|
||||
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
||||
Use the following format:
|
||||
|
||||
Question: "Question here"
|
||||
SQLQuery: "SQL Query to run"
|
||||
SQLResult: "Result of the SQLQuery"
|
||||
Answer: "Final answer here"
|
||||
|
||||
Only use the following tables:
|
||||
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
PROMPT = Prompt(
|
||||
input_variables=["input", "table_info", "dialect"],
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
)
|
41
langchain/sql_database.py
Normal file
41
langchain/sql_database.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class SQLDatabase:
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
def __init__(self, engine: Engine):
|
||||
"""Create engine from database URI."""
|
||||
self._engine = engine
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, database_uri: str) -> "SQLDatabase":
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
return cls(create_engine(database_uri))
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return string representation of dialect to use."""
|
||||
return self._engine.dialect.name
|
||||
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
"""Information about all tables in the database."""
|
||||
template = "The '{table_name}' table has columns: {columns}."
|
||||
tables = []
|
||||
inspector = inspect(self._engine)
|
||||
for table_name in inspector.get_table_names():
|
||||
columns = []
|
||||
for column in inspector.get_columns(table_name):
|
||||
columns.append(f"{column['name']} ({str(column['type'])})")
|
||||
column_str = ", ".join(columns)
|
||||
table_str = template.format(table_name=table_name, columns=column_str)
|
||||
tables.append(table_str)
|
||||
return "\n".join(tables)
|
||||
|
||||
def run(self, command: str) -> str:
|
||||
"""Execute a SQL command and return a string of the results."""
|
||||
result = self._engine.execute(command).fetchall()
|
||||
return str(result)
|
2
setup.py
2
setup.py
@ -14,7 +14,7 @@ setup(
|
||||
version=__version__,
|
||||
packages=find_packages(),
|
||||
description="Building applications with LLMs through composability",
|
||||
install_requires=["pydantic"],
|
||||
install_requires=["pydantic", "sqlalchemy"],
|
||||
long_description=long_description,
|
||||
license="MIT",
|
||||
url="https://github.com/hwchase17/langchain",
|
||||
|
30
tests/integration_tests/chains/test_sql_database.py
Normal file
30
tests/integration_tests/chains/test_sql_database.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Test SQL Database Chain."""
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||
|
||||
from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
user = Table(
|
||||
"user",
|
||||
metadata_obj,
|
||||
Column("user_id", Integer, primary_key=True),
|
||||
Column("user_name", String(16), nullable=False),
|
||||
Column("user_company", String(16), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
"""Test that commands can be run successfully and returned in correct format."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
stmt = insert(user).values(user_id=13, user_name="Harrison", user_company="Foo")
|
||||
with engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
|
||||
output = db_chain.query("What company does Harrison work at?")
|
||||
expected_output = " Harrison works at Foo."
|
||||
assert output == expected_output
|
49
tests/unit_tests/test_sql_database.py
Normal file
49
tests/unit_tests/test_sql_database.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Test SQL database wrapper."""
|
||||
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
user = Table(
|
||||
"user",
|
||||
metadata_obj,
|
||||
Column("user_id", Integer, primary_key=True),
|
||||
Column("user_name", String(16), nullable=False),
|
||||
)
|
||||
|
||||
company = Table(
|
||||
"company",
|
||||
metadata_obj,
|
||||
Column("company_id", Integer, primary_key=True),
|
||||
Column("company_location", String, nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def test_table_info() -> None:
|
||||
"""Test that table info is constructed properly."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
db = SQLDatabase(engine)
|
||||
output = db.table_info
|
||||
expected_output = (
|
||||
"The 'company' table has columns: company_id (INTEGER), "
|
||||
"company_location (VARCHAR).\n"
|
||||
"The 'user' table has columns: user_id (INTEGER), user_name (VARCHAR(16))."
|
||||
)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
"""Test that commands can be run successfully and returned in correct format."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
stmt = insert(user).values(user_id=13, user_name="Harrison")
|
||||
with engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
command = "select user_name from user where user_id = 13"
|
||||
output = db.run(command)
|
||||
expected_output = "[('Harrison',)]"
|
||||
assert output == expected_output
|
Loading…
Reference in New Issue
Block a user