mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-26 13:21:40 +00:00
Currently the chain is getting the column names and types on the one side and the example rows on the other. It is easier for the llm to read the table information if the column name and examples are shown together so that it can easily understand to which columns do the examples refer to. For an instantiation of this, please refer to the changes in the `sqlite.ipynb` notebook. Also changed `eval` for `ast.literal_eval` when interpreting the results from the sample row query since it is a better practice. --------- Co-authored-by: Francisco Ingham <> --------- Co-authored-by: Francisco Ingham <fpingham@gmail.com>
This commit is contained in:
@@ -1,11 +1,25 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
_TEMPLATE_PREFIX = """Table data will be described in the following format:
|
||||
|
||||
Table 'table name' has columns: {
|
||||
column1 name: (column1 type, [list of example values for column1]),
|
||||
column2 name: (column2 type, [list of example values for column2]),
|
||||
...
|
||||
}
|
||||
|
||||
These are the tables you can use, together with their column information:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SQLDatabase:
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
@@ -77,38 +91,33 @@ class SQLDatabase:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
|
||||
template = "Table '{table_name}' has columns: {columns}."
|
||||
|
||||
tables = []
|
||||
for table_name in all_table_names:
|
||||
columns = []
|
||||
columns = defaultdict(list)
|
||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
||||
columns.append(f"{column['name']} ({str(column['type'])})")
|
||||
column_str = ", ".join(columns)
|
||||
table_str = template.format(table_name=table_name, columns=column_str)
|
||||
columns[f"{column['name']}"].append(str(column["type"]))
|
||||
|
||||
if self._sample_rows_in_table_info:
|
||||
row_template = (
|
||||
" Here is an example of {n_rows} rows from this table "
|
||||
"(long strings are truncated):\n"
|
||||
"{sample_rows}"
|
||||
)
|
||||
sample_rows = self.run(
|
||||
f"SELECT * FROM '{table_name}' LIMIT "
|
||||
f"{self._sample_rows_in_table_info}"
|
||||
)
|
||||
sample_rows = eval(sample_rows)
|
||||
if len(sample_rows) > 0:
|
||||
n_rows = len(sample_rows)
|
||||
sample_rows = "\n".join(
|
||||
[" ".join([str(i)[:100] for i in row]) for row in sample_rows]
|
||||
)
|
||||
table_str += row_template.format(
|
||||
n_rows=n_rows, sample_rows=sample_rows
|
||||
|
||||
sample_rows_ls = ast.literal_eval(sample_rows)
|
||||
sample_rows_ls = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
|
||||
)
|
||||
|
||||
for e, col in enumerate(columns):
|
||||
columns[col].append(
|
||||
[row[e] for row in sample_rows_ls] # type: ignore
|
||||
)
|
||||
|
||||
table_str = f"Table '{table_name}' has columns: " + str(dict(columns))
|
||||
tables.append(table_str)
|
||||
return "\n".join(tables)
|
||||
|
||||
final_str = _TEMPLATE_PREFIX + "\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
Reference in New Issue
Block a user