From 4278046329691013c61753c2ced14ffd65502e10 Mon Sep 17 00:00:00 2001 From: Syed Baqar Abbas <76920434+SyedBaqarAbbas@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:00:03 +0500 Subject: [PATCH] [fix] Convert table names to list for compatibility in SQLDatabase (#29229) - [langchain_community.utilities.SQLDatabase] **[fix] Convert table names to list for compatibility in SQLDatabase**: - The issue #29227 is being fixed here - The "package" modified is community - The issue lied in this block of code: https://github.com/langchain-ai/langchain/blob/44b41b699c3815206413125ec58b6cca601ee438/libs/community/langchain_community/utilities/sql_database.py#L72-L77 - [langchain_community.utilities.SQLDatabase] **[fix] Convert table names to list for compatibility in SQLDatabase**: - **Description:** When the SQLDatabase is initialized, it runs a code `self._inspector.get_table_names(schema=schema)` which expects an output of list. However, with some connectors (such as snowflake) the data type returned could be another iterable. This results in a type error when concatenating the table_names to view_names. I have added explicit type casting to prevent this. - **Issue:** The issue #29227 is being fixed here - **Dependencies:** None - **Twitter handle:** @BaqarAbbas2001 ## Additional Information When the following method is called for a Snowflake database: https://github.com/langchain-ai/langchain/blob/44b41b699c3815206413125ec58b6cca601ee438/libs/community/langchain_community/utilities/sql_database.py#L75 Snowflake under the hood calls: ```python from snowflake.sqlalchemy.snowdialect import SnowflakeDialect SnowflakeDialect.get_table_names ``` This method returns a `dict_keys()` object which is incompatible to concatenate with a list and results in a `TypeError` ### Relevant Library Versions - **snowflake-sqlalchemy**: 1.7.2 - **snowflake-connector-python**: 3.12.4 - **sqlalchemy**: 2.0.20 - **langchain_community**: 0.3.14 --- libs/community/langchain_community/utilities/sql_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index 85f0a107677..2a091fd2cf3 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -72,7 +72,7 @@ class SQLDatabase: # including view support by adding the views as well as tables to the all # tables list if view_support is True self._all_tables = set( - self._inspector.get_table_names(schema=schema) + list(self._inspector.get_table_names(schema=schema)) + (self._inspector.get_view_names(schema=schema) if view_support else []) )