Sample row in table info for SQLDatabase (#769) (#782)

Agents usually benefit from seeing what the data looks like so that they
can filter effectively. Including a single sample row in the table info
lets the agent understand the data before querying and get better
results.

---------

Co-authored-by: Francisco Ingham <fpingham@gmail.com>

---------

Co-authored-by: Francisco Ingham <fpingham@gmail.com>
This commit is contained in:
Harrison Chase
2023-01-28 13:37:07 -08:00
committed by GitHub
parent 213c2e33e5
commit 248c297f1b
3 changed files with 107 additions and 10 deletions

View File

@@ -16,6 +16,7 @@ class SQLDatabase:
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_row_in_table_info: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
@@ -39,6 +40,7 @@ class SQLDatabase:
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
self._sample_row_in_table_info = sample_row_in_table_info
@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
@@ -69,14 +71,28 @@ class SQLDatabase:
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
template = "Table '{table_name}' has columns: {columns}."
tables = []
for table_name in all_table_names:
columns = []
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(f"{column['name']} ({str(column['type'])})")
column_str = ", ".join(columns)
table_str = template.format(table_name=table_name, columns=column_str)
if self._sample_row_in_table_info:
row_template = (
" Here is an example row for this table"
" (long strings are truncated): {sample_row}."
)
sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1")
if len(eval(sample_row)) > 0:
sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]])
table_str += row_template.format(sample_row=sample_row)
tables.append(table_str)
return "\n".join(tables)