more complex sql chain (#619)

add a more complex sql chain that first subsets the necessary tables
This commit is contained in:
Harrison Chase
2023-01-15 17:07:21 -08:00
committed by GitHub
parent 49b3d6c78c
commit 1c71fadfdc
6 changed files with 184 additions and 8 deletions

View File

@@ -50,7 +50,8 @@ class SQLDatabase:
"""Return string representation of dialect to use."""
return self._engine.dialect.name
def _get_table_names(self) -> Iterable[str]:
def get_table_names(self) -> Iterable[str]:
"""Get names of tables available."""
if self._include_tables:
return self._include_tables
return set(self._all_tables) - set(self._ignore_tables)
@@ -58,9 +59,19 @@ class SQLDatabase:
@property
def table_info(self) -> str:
"""Information about all tables in the database."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables."""
all_table_names = self.get_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names
template = "Table '{table_name}' has columns: {columns}."
tables = []
for table_name in self._get_table_names():
for table_name in all_table_names:
columns = []
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(f"{column['name']} ({str(column['type'])})")