mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
Currently the chain is getting the column names and types on the one side and the example rows on the other. It is easier for the llm to read the table information if the column name and examples are shown together so that it can easily understand to which columns do the examples refer to. For an instantiation of this, please refer to the changes in the `sqlite.ipynb` notebook. Also changed `eval` for `ast.literal_eval` when interpreting the results from the sample row query since it is a better practice. --------- Co-authored-by: Francisco Ingham <> --------- Co-authored-by: Francisco Ingham <fpingham@gmail.com>
This commit is contained in:
parent
8c45f06d58
commit
ec727bf166
@ -287,14 +287,14 @@
|
||||
"What are some example tracks by composer Johann Sebastian Bach? \n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach')]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by Johann Sebastian Bach include 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m Examples of tracks by composer Johann Sebastian Bach are 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', and 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude'.\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Examples of tracks by Johann Sebastian Bach include \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'"
|
||||
"' Examples of tracks by composer Johann Sebastian Bach are \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', and \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\'.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
@ -317,13 +317,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 13,
|
||||
"id": "9a22ee47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = SQLDatabase.from_uri(\n",
|
||||
" \"sqlite:///../../../../notebooks/Chinook.db\", \n",
|
||||
" \"sqlite:///../../../../notebooks/Chinook.db\",\n",
|
||||
" include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n",
|
||||
" sample_rows_in_table_info=2)"
|
||||
]
|
||||
@ -338,7 +338,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 14,
|
||||
"id": "9de86267",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -346,9 +346,15 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example of 2 rows from this table (long strings are truncated):\n",
|
||||
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
|
||||
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
|
||||
"\n",
|
||||
" Table data will be described in the following format:\n",
|
||||
"\n",
|
||||
" Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n",
|
||||
" column2 name: (column2 type, [list of example values for column2], ...)\n",
|
||||
"\n",
|
||||
" These are the tables you can use, together with their column information:\n",
|
||||
"\n",
|
||||
" Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -358,7 +364,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 15,
|
||||
"id": "bcb7a489",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -368,7 +374,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 16,
|
||||
"id": "81e05d82",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -380,8 +386,8 @@
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
||||
"What are some example tracks by Bach? \n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name, Composer FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach'), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata', 'Johann Sebastian Bach')]\u001b[0m\n",
|
||||
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT Name FROM Track WHERE Composer LIKE '%Bach%' LIMIT 5;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[('American Woman',), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata',)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', and 'Toccata and Fugue in D Minor, BWV 565: I. Toccata'.\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
@ -392,7 +398,7 @@
|
||||
"' Some example tracks by Bach are \\'American Woman\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\\', and \\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\\'.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -500,7 +506,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.2"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -15,7 +15,7 @@ SQLQuery: "SQL Query to run"
|
||||
SQLResult: "Result of the SQLQuery"
|
||||
Answer: "Final answer here"
|
||||
|
||||
Only use the following tables:
|
||||
Only use the tables listed below.
|
||||
|
||||
{table_info}
|
||||
|
||||
|
@ -1,11 +1,25 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
_TEMPLATE_PREFIX = """Table data will be described in the following format:
|
||||
|
||||
Table 'table name' has columns: {
|
||||
column1 name: (column1 type, [list of example values for column1]),
|
||||
column2 name: (column2 type, [list of example values for column2]),
|
||||
...
|
||||
}
|
||||
|
||||
These are the tables you can use, together with their column information:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SQLDatabase:
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
@ -77,38 +91,33 @@ class SQLDatabase:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
|
||||
template = "Table '{table_name}' has columns: {columns}."
|
||||
|
||||
tables = []
|
||||
for table_name in all_table_names:
|
||||
columns = []
|
||||
columns = defaultdict(list)
|
||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
||||
columns.append(f"{column['name']} ({str(column['type'])})")
|
||||
column_str = ", ".join(columns)
|
||||
table_str = template.format(table_name=table_name, columns=column_str)
|
||||
columns[f"{column['name']}"].append(str(column["type"]))
|
||||
|
||||
if self._sample_rows_in_table_info:
|
||||
row_template = (
|
||||
" Here is an example of {n_rows} rows from this table "
|
||||
"(long strings are truncated):\n"
|
||||
"{sample_rows}"
|
||||
)
|
||||
sample_rows = self.run(
|
||||
f"SELECT * FROM '{table_name}' LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
sample_rows = eval(sample_rows)
|
||||
if len(sample_rows) > 0:
|
||||
n_rows = len(sample_rows)
|
||||
sample_rows = "\n".join(
|
||||
[" ".join([str(i)[:100] for i in row]) for row in sample_rows]
|
||||
)
|
||||
table_str += row_template.format(
|
||||
n_rows=n_rows, sample_rows=sample_rows
|
||||
|
||||
sample_rows_ls = ast.literal_eval(sample_rows)
|
||||
sample_rows_ls = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
|
||||
)
|
||||
|
||||
for e, col in enumerate(columns):
|
||||
columns[col].append(
|
||||
[row[e] for row in sample_rows_ls] # type: ignore
|
||||
)
|
||||
|
||||
table_str = f"Table '{table_name}' has columns: " + str(dict(columns))
|
||||
tables.append(table_str)
|
||||
return "\n".join(tables)
|
||||
|
||||
final_str = _TEMPLATE_PREFIX + "\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
@ -1,8 +1,9 @@
|
||||
# flake8: noqa=E501
|
||||
"""Test SQL database wrapper."""
|
||||
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert
|
||||
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
@ -27,10 +28,10 @@ def test_table_info() -> None:
|
||||
metadata_obj.create_all(engine)
|
||||
db = SQLDatabase(engine)
|
||||
output = db.table_info
|
||||
output = output[len(_TEMPLATE_PREFIX) :]
|
||||
expected_output = (
|
||||
"Table 'company' has columns: company_id (INTEGER), "
|
||||
"company_location (VARCHAR).",
|
||||
"Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR(16)).",
|
||||
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}",
|
||||
"Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}",
|
||||
)
|
||||
assert sorted(output.split("\n")) == sorted(expected_output)
|
||||
|
||||
@ -50,14 +51,12 @@ def test_table_info_w_sample_rows() -> None:
|
||||
db = SQLDatabase(engine, sample_rows_in_table_info=2)
|
||||
|
||||
output = db.table_info
|
||||
output = output[len(_TEMPLATE_PREFIX) :]
|
||||
expected_output = (
|
||||
"Table 'company' has columns: company_id (INTEGER), "
|
||||
"company_location (VARCHAR).\n"
|
||||
"Table 'user' has columns: user_id (INTEGER), "
|
||||
"user_name (VARCHAR(16)). Here is an example of 2 rows "
|
||||
"from this table (long strings are truncated):\n13 Harrison\n14 Chase"
|
||||
"Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}",
|
||||
"Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}",
|
||||
)
|
||||
assert sorted(output.split("\n")) == sorted(expected_output.split("\n"))
|
||||
assert sorted(output.split("\n")) == sorted(expected_output)
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
|
@ -16,7 +16,7 @@ from sqlalchemy import (
|
||||
schema,
|
||||
)
|
||||
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
|
||||
|
||||
metadata_obj = MetaData()
|
||||
|
||||
@ -46,10 +46,11 @@ def test_table_info() -> None:
|
||||
metadata_obj.create_all(engine)
|
||||
db = SQLDatabase(engine, schema="schema_a")
|
||||
output = db.table_info
|
||||
output = output[len(_TEMPLATE_PREFIX) :]
|
||||
expected_output = (
|
||||
"Table 'user' has columns: user_id (INTEGER), user_name (VARCHAR).",
|
||||
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}"
|
||||
)
|
||||
assert sorted(output.split("\n")) == sorted(expected_output)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user