diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index fb542cfa883..a168e7b7fb7 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -58,6 +58,7 @@ class SQLDatabase: custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, + lazy_table_reflection: bool = False, ): """Create engine from database URI.""" self._engine = engine @@ -113,15 +114,17 @@ class SQLDatabase: ) self._max_string_length = max_string_length + self._view_support = view_support self._metadata = metadata or MetaData() - # including view support if view_support = true - self._metadata.reflect( - views=view_support, - bind=self._engine, - only=list(self._usable_tables), - schema=self._schema, - ) + if not lazy_table_reflection: + # including view support if view_support = true + self._metadata.reflect( + views=view_support, + bind=self._engine, + only=list(self._usable_tables), + schema=self._schema, + ) @classmethod def from_uri( @@ -307,6 +310,16 @@ class SQLDatabase: raise ValueError(f"table_names {missing_tables} not found in database") all_table_names = table_names + metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables] + to_reflect = set(all_table_names) - set(metadata_table_names) + if to_reflect: + self._metadata.reflect( + views=self._view_support, + bind=self._engine, + only=list(to_reflect), + schema=self._schema, + ) + meta_tables = [ tbl for tbl in self._metadata.sorted_tables diff --git a/libs/community/tests/unit_tests/test_sql_database.py b/libs/community/tests/unit_tests/test_sql_database.py index 4293265a1de..da4c1ddbea6 100644 --- a/libs/community/tests/unit_tests/test_sql_database.py +++ b/libs/community/tests/unit_tests/test_sql_database.py @@ -45,6 +45,12 @@ def db(engine: sa.Engine) -> SQLDatabase: return SQLDatabase(engine) +@pytest.fixture +def db_lazy_reflection(engine: sa.Engine) -> SQLDatabase: + metadata_obj.create_all(engine) + return SQLDatabase(engine, lazy_table_reflection=True) + + def test_table_info(db: SQLDatabase) -> None: """Test that table info is constructed properly.""" output = db.table_info @@ -75,6 +81,32 @@ def test_table_info(db: SQLDatabase) -> None: assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) +def test_table_info_lazy_reflection(db_lazy_reflection: SQLDatabase) -> None: + """Test that table info with lazy reflection""" + assert len(db_lazy_reflection._metadata.sorted_tables) == 0 + output = db_lazy_reflection.get_table_info(["user"]) + assert len(db_lazy_reflection._metadata.sorted_tables) == 1 + expected_output = """ + CREATE TABLE user ( + user_id INTEGER NOT NULL, + user_name VARCHAR(16) NOT NULL, + user_bio TEXT, + PRIMARY KEY (user_id) + ) + /* + 3 rows from user table: + user_id user_name user_bio + /* + """ + + assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split())) + + db_lazy_reflection.get_table_info(["company"]) + assert len(db_lazy_reflection._metadata.sorted_tables) == 2 + assert db_lazy_reflection._metadata.sorted_tables[0].name == "company" + assert db_lazy_reflection._metadata.sorted_tables[1].name == "user" + + def test_table_info_w_sample_rows(db: SQLDatabase) -> None: """Test that table info is constructed properly."""