mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-27 13:53:43 +00:00
more complex sql chain (#619)
add a more complex sql chain that first subsets the necessary tables
This commit is contained in:
@@ -50,7 +50,8 @@ class SQLDatabase:
|
||||
"""Return string representation of dialect to use."""
|
||||
return self._engine.dialect.name
|
||||
|
||||
def _get_table_names(self) -> Iterable[str]:
|
||||
def get_table_names(self) -> Iterable[str]:
|
||||
"""Get names of tables available."""
|
||||
if self._include_tables:
|
||||
return self._include_tables
|
||||
return set(self._all_tables) - set(self._ignore_tables)
|
||||
@@ -58,9 +59,19 @@ class SQLDatabase:
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
"""Information about all tables in the database."""
|
||||
return self.get_table_info()
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables."""
|
||||
all_table_names = self.get_table_names()
|
||||
if table_names is not None:
|
||||
missing_tables = set(table_names).difference(all_table_names)
|
||||
if missing_tables:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
template = "Table '{table_name}' has columns: {columns}."
|
||||
tables = []
|
||||
for table_name in self._get_table_names():
|
||||
for table_name in all_table_names:
|
||||
columns = []
|
||||
for column in self._inspector.get_columns(table_name, schema=self._schema):
|
||||
columns.append(f"{column['name']} ({str(column['type'])})")
|
||||
|
||||
Reference in New Issue
Block a user