Align table info (#999) (#1034)

Currently the chain gets the column names and types on one side and the
example rows on the other. It is easier for the LLM to read the table
information if each column's name and example values are shown together,
so that it can easily tell which column each example refers to. For a
concrete example of the new format, please refer to the changes in the
`sqlite.ipynb` notebook.

Also replaced `eval` with `ast.literal_eval` when interpreting the results
from the sample row query, since `ast.literal_eval` only evaluates Python
literals and is therefore the safer practice.

---------

Co-authored-by: Francisco Ingham <>

---------

Co-authored-by: Francisco Ingham <fpingham@gmail.com>
This commit is contained in:
Harrison Chase
2023-02-13 21:48:41 -08:00
committed by GitHub
parent 8c45f06d58
commit ec727bf166
5 changed files with 63 additions and 48 deletions

View File

@@ -1,11 +1,25 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import ast
from collections import defaultdict
from typing import Any, Iterable, List, Optional
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import Engine
# Preamble prepended to the joined per-table descriptions when the table info
# string is built. It explains the layout of each table entry: a mapping of
# column name to (column type, list of example values).
_TEMPLATE_PREFIX = """Table data will be described in the following format:
Table 'table name' has columns: {
column1 name: (column1 type, [list of example values for column1]),
column2 name: (column2 type, [list of example values for column2]),
...
}
These are the tables you can use, together with their column information:
"""
class SQLDatabase:
"""SQLAlchemy wrapper around a database."""
@@ -77,38 +91,33 @@ class SQLDatabase:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
template = "Table '{table_name}' has columns: {columns}."
tables = []
for table_name in all_table_names:
columns = []
columns = defaultdict(list)
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(f"{column['name']} ({str(column['type'])})")
column_str = ", ".join(columns)
table_str = template.format(table_name=table_name, columns=column_str)
columns[f"{column['name']}"].append(str(column["type"]))
if self._sample_rows_in_table_info:
row_template = (
" Here is an example of {n_rows} rows from this table "
"(long strings are truncated):\n"
"{sample_rows}"
)
sample_rows = self.run(
f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)
sample_rows = eval(sample_rows)
if len(sample_rows) > 0:
n_rows = len(sample_rows)
sample_rows = "\n".join(
[" ".join([str(i)[:100] for i in row]) for row in sample_rows]
)
table_str += row_template.format(
n_rows=n_rows, sample_rows=sample_rows
sample_rows_ls = ast.literal_eval(sample_rows)
sample_rows_ls = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
)
for e, col in enumerate(columns):
columns[col].append(
[row[e] for row in sample_rows_ls] # type: ignore
)
table_str = f"Table '{table_name}' has columns: " + str(dict(columns))
tables.append(table_str)
return "\n".join(tables)
final_str = _TEMPLATE_PREFIX + "\n".join(tables)
return final_str
def run(self, command: str) -> str:
"""Execute a SQL command and return a string representing the results.