mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 15:59:56 +00:00
remove sample_row_in_table_info and simplify set operations in SQLDB (#932)
-Address TODO: deprecate for sample_row_in_table_info -Simplify set operations by casting to sets to not need multiple set casts + .difference() calls
This commit is contained in:
parent
e323d0cfb1
commit
512c523368
@ -17,40 +17,30 @@ class SQLDatabase:
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 0,
|
||||
# TODO: deprecate.
|
||||
sample_row_in_table_info: bool = False,
|
||||
):
|
||||
"""Create engine from database URI."""
|
||||
if sample_row_in_table_info and sample_rows_in_table_info > 0:
|
||||
raise ValueError(
|
||||
"Only one of `sample_row_in_table_info` "
|
||||
"and `sample_rows_in_table_info` should be set"
|
||||
)
|
||||
self._engine = engine
|
||||
self._schema = schema
|
||||
if include_tables and ignore_tables:
|
||||
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||
|
||||
self._inspector = inspect(self._engine)
|
||||
self._all_tables = self._inspector.get_table_names(schema=schema)
|
||||
self._include_tables = include_tables or []
|
||||
self._all_tables = set(self._inspector.get_table_names(schema=schema))
|
||||
self._include_tables = set(include_tables) if include_tables else set()
|
||||
if self._include_tables:
|
||||
missing_tables = set(self._include_tables).difference(self._all_tables)
|
||||
missing_tables = self._include_tables - self._all_tables
|
||||
if missing_tables:
|
||||
raise ValueError(
|
||||
f"include_tables {missing_tables} not found in database"
|
||||
)
|
||||
self._ignore_tables = ignore_tables or []
|
||||
self._ignore_tables = set(ignore_tables) if ignore_tables else set()
|
||||
if self._ignore_tables:
|
||||
missing_tables = set(self._ignore_tables).difference(self._all_tables)
|
||||
missing_tables = self._ignore_tables - self._all_tables
|
||||
if missing_tables:
|
||||
raise ValueError(
|
||||
f"ignore_tables {missing_tables} not found in database"
|
||||
)
|
||||
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||
# TODO: deprecate
|
||||
if sample_row_in_table_info:
|
||||
self._sample_rows_in_table_info = 1
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
|
||||
@ -66,7 +56,7 @@ class SQLDatabase:
|
||||
"""Get names of tables available."""
|
||||
if self._include_tables:
|
||||
return self._include_tables
|
||||
return set(self._all_tables) - set(self._ignore_tables)
|
||||
return self._all_tables - self._ignore_tables
|
||||
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user