From ae8bc9e8302e87fdd4d9a22eb47c91ce849eb431 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Thu, 20 Jul 2023 22:17:55 -0700 Subject: [PATCH] Refactored `sql_database` (#7945) `sql_database.py` is unnecessarily placed in the root code folder. Similar code is usually placed in `utilities/`. As a byproduct of this placement, `sql_database` is [placed at the top level of classes in the API Reference](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.sql_database), which is confusing and incorrect. - Moved `sql_database.py` from the root code folder to `utilities/` @baskaryan --------- Co-authored-by: Harrison Chase --- langchain/__init__.py | 2 +- .../agents/agent_toolkits/sql/toolkit.py | 2 +- langchain/chains/sql_database/base.py | 2 +- langchain/sql_database.py | 447 +----------------- langchain/tools/sql_database/tool.py | 2 +- langchain/utilities/__init__.py | 2 + langchain/utilities/sql_database.py | 445 +++++++++++++++++ .../chains/test_sql_database.py | 2 +- tests/unit_tests/agents/test_sql.py | 2 +- tests/unit_tests/test_sql_database.py | 2 +- tests/unit_tests/test_sql_database_schema.py | 2 +- 11 files changed, 458 insertions(+), 452 deletions(-) create mode 100644 langchain/utilities/sql_database.py diff --git a/langchain/__init__.py b/langchain/__init__.py index 0fe0260aa32..6b0c3f1351d 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -43,7 +43,6 @@ from langchain.prompts import ( PromptTemplate, ) from langchain.schema.prompt_template import BasePromptTemplate -from langchain.sql_database import SQLDatabase from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.golden_query import GoldenQueryAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper @@ -51,6 +50,7 @@ from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.powerbi import PowerBIDataset from langchain.utilities.searx_search import 
SearxSearchWrapper from langchain.utilities.serpapi import SerpAPIWrapper +from langchain.utilities.sql_database import SQLDatabase from langchain.utilities.wikipedia import WikipediaAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper from langchain.vectorstores import FAISS, ElasticVectorSearch diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index 5e5bd4cf75c..5fe6e078632 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -5,7 +5,6 @@ from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.schema.language_model import BaseLanguageModel -from langchain.sql_database import SQLDatabase from langchain.tools import BaseTool from langchain.tools.sql_database.tool import ( InfoSQLDatabaseTool, @@ -13,6 +12,7 @@ from langchain.tools.sql_database.tool import ( QuerySQLCheckerTool, QuerySQLDataBaseTool, ) +from langchain.utilities.sql_database import SQLDatabase class SQLDatabaseToolkit(BaseToolkit): diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index c69fdf0b162..0ba4c4c5880 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -13,8 +13,8 @@ from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PRO from langchain.prompts.prompt import PromptTemplate from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel -from langchain.sql_database import SQLDatabase from langchain.tools.sql_database.prompt import QUERY_CHECKER +from langchain.utilities.sql_database import SQLDatabase INTERMEDIATE_STEPS_KEY = "intermediate_steps" diff --git a/langchain/sql_database.py b/langchain/sql_database.py index 43a43e86fd9..47623c04e77 100644 --- a/langchain/sql_database.py +++ b/langchain/sql_database.py @@ -1,445 +1,4 @@ -"""SQLAlchemy 
wrapper around a database.""" -from __future__ import annotations +"""Keep here for backwards compatibility.""" +from langchain.utilities.sql_database import SQLDatabase -import warnings -from typing import Any, Iterable, List, Optional - -import sqlalchemy -from sqlalchemy import MetaData, Table, create_engine, inspect, select, text -from sqlalchemy.engine import Engine -from sqlalchemy.exc import ProgrammingError, SQLAlchemyError -from sqlalchemy.schema import CreateTable - -from langchain.utils import get_from_env - - -def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: - return ( - f'Name: {index["name"]}, Unique: {index["unique"]},' - f' Columns: {str(index["column_names"])}' - ) - - -def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: - """ - Truncate a string to a certain number of words, based on the max string - length. - """ - - if not isinstance(content, str) or length <= 0: - return content - - if len(content) <= length: - return content - - return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix - - -class SQLDatabase: - """SQLAlchemy wrapper around a database.""" - - def __init__( - self, - engine: Engine, - schema: Optional[str] = None, - metadata: Optional[MetaData] = None, - ignore_tables: Optional[List[str]] = None, - include_tables: Optional[List[str]] = None, - sample_rows_in_table_info: int = 3, - indexes_in_table_info: bool = False, - custom_table_info: Optional[dict] = None, - view_support: bool = False, - max_string_length: int = 300, - ): - """Create engine from database URI.""" - self._engine = engine - self._schema = schema - if include_tables and ignore_tables: - raise ValueError("Cannot specify both include_tables and ignore_tables") - - self._inspector = inspect(self._engine) - - # including view support by adding the views as well as tables to the all - # tables list if view_support is True - self._all_tables = set( - self._inspector.get_table_names(schema=schema) - + 
(self._inspector.get_view_names(schema=schema) if view_support else []) - ) - - self._include_tables = set(include_tables) if include_tables else set() - if self._include_tables: - missing_tables = self._include_tables - self._all_tables - if missing_tables: - raise ValueError( - f"include_tables {missing_tables} not found in database" - ) - self._ignore_tables = set(ignore_tables) if ignore_tables else set() - if self._ignore_tables: - missing_tables = self._ignore_tables - self._all_tables - if missing_tables: - raise ValueError( - f"ignore_tables {missing_tables} not found in database" - ) - usable_tables = self.get_usable_table_names() - self._usable_tables = set(usable_tables) if usable_tables else self._all_tables - - if not isinstance(sample_rows_in_table_info, int): - raise TypeError("sample_rows_in_table_info must be an integer") - - self._sample_rows_in_table_info = sample_rows_in_table_info - self._indexes_in_table_info = indexes_in_table_info - - self._custom_table_info = custom_table_info - if self._custom_table_info: - if not isinstance(self._custom_table_info, dict): - raise TypeError( - "table_info must be a dictionary with table names as keys and the " - "desired table info as values" - ) - # only keep the tables that are also present in the database - intersection = set(self._custom_table_info).intersection(self._all_tables) - self._custom_table_info = dict( - (table, self._custom_table_info[table]) - for table in self._custom_table_info - if table in intersection - ) - - self._max_string_length = max_string_length - - self._metadata = metadata or MetaData() - # including view support if view_support = true - self._metadata.reflect( - views=view_support, - bind=self._engine, - only=list(self._usable_tables), - schema=self._schema, - ) - - @classmethod - def from_uri( - cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any - ) -> SQLDatabase: - """Construct a SQLAlchemy engine from URI.""" - _engine_args = engine_args or {} - 
return cls(create_engine(database_uri, **_engine_args), **kwargs) - - @classmethod - def from_databricks( - cls, - catalog: str, - schema: str, - host: Optional[str] = None, - api_token: Optional[str] = None, - warehouse_id: Optional[str] = None, - cluster_id: Optional[str] = None, - engine_args: Optional[dict] = None, - **kwargs: Any, - ) -> SQLDatabase: - """ - Class method to create an SQLDatabase instance from a Databricks connection. - This method requires the 'databricks-sql-connector' package. If not installed, - it can be added using `pip install databricks-sql-connector`. - - Args: - catalog (str): The catalog name in the Databricks database. - schema (str): The schema name in the catalog. - host (Optional[str]): The Databricks workspace hostname, excluding - 'https://' part. If not provided, it attempts to fetch from the - environment variable 'DATABRICKS_HOST'. If still unavailable and if - running in a Databricks notebook, it defaults to the current workspace - hostname. Defaults to None. - api_token (Optional[str]): The Databricks personal access token for - accessing the Databricks SQL warehouse or the cluster. If not provided, - it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable - and running in a Databricks notebook, a temporary token for the current - user is generated. Defaults to None. - warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If - provided, the method configures the connection to use this warehouse. - Cannot be used with 'cluster_id'. Defaults to None. - cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If - provided, the method configures the connection to use this cluster. - Cannot be used with 'warehouse_id'. If running in a Databricks notebook - and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the - cluster the notebook is attached to. Defaults to None. - engine_args (Optional[dict]): The arguments to be used when connecting - Databricks. Defaults to None. 
- **kwargs (Any): Additional keyword arguments for the `from_uri` method. - - Returns: - SQLDatabase: An instance of SQLDatabase configured with the provided - Databricks connection details. - - Raises: - ValueError: If 'databricks-sql-connector' is not found, or if both - 'warehouse_id' and 'cluster_id' are provided, or if neither - 'warehouse_id' nor 'cluster_id' are provided and it's not executing - inside a Databricks notebook. - """ - try: - from databricks import sql # noqa: F401 - except ImportError: - raise ValueError( - "databricks-sql-connector package not found, please install with" - " `pip install databricks-sql-connector`" - ) - context = None - try: - from dbruntime.databricks_repl_context import get_context - - context = get_context() - except ImportError: - pass - - default_host = context.browserHostName if context else None - if host is None: - host = get_from_env("host", "DATABRICKS_HOST", default_host) - - default_api_token = context.apiToken if context else None - if api_token is None: - api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token) - - if warehouse_id is None and cluster_id is None: - if context: - cluster_id = context.clusterId - else: - raise ValueError( - "Need to provide either 'warehouse_id' or 'cluster_id'." - ) - - if warehouse_id and cluster_id: - raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.") - - if warehouse_id: - http_path = f"/sql/1.0/warehouses/{warehouse_id}" - else: - http_path = f"/sql/protocolv1/o/0/{cluster_id}" - - uri = ( - f"databricks://token:{api_token}@{host}?" - f"http_path={http_path}&catalog={catalog}&schema={schema}" - ) - return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs) - - @classmethod - def from_cnosdb( - cls, - url: str = "127.0.0.1:8902", - user: str = "root", - password: str = "", - tenant: str = "cnosdb", - database: str = "public", - ) -> SQLDatabase: - """ - Class method to create an SQLDatabase instance from a CnosDB connection. 
- This method requires the 'cnos-connector' package. If not installed, it - can be added using `pip install cnos-connector`. - - Args: - url (str): The HTTP connection host name and port number of the CnosDB - service, excluding "http://" or "https://", with a default value - of "127.0.0.1:8902". - user (str): The username used to connect to the CnosDB service, with a - default value of "root". - password (str): The password of the user connecting to the CnosDB service, - with a default value of "". - tenant (str): The name of the tenant used to connect to the CnosDB service, - with a default value of "cnosdb". - database (str): The name of the database in the CnosDB tenant. - - Returns: - SQLDatabase: An instance of SQLDatabase configured with the provided - CnosDB connection details. - """ - try: - from cnosdb_connector import make_cnosdb_langchain_uri - - uri = make_cnosdb_langchain_uri(url, user, password, tenant, database) - return cls.from_uri(database_uri=uri) - except ImportError: - raise ValueError( - "cnos-connector package not found, please install with" - " `pip install cnos-connector`" - ) - - @property - def dialect(self) -> str: - """Return string representation of dialect to use.""" - return self._engine.dialect.name - - def get_usable_table_names(self) -> Iterable[str]: - """Get names of tables available.""" - if self._include_tables: - return sorted(self._include_tables) - return sorted(self._all_tables - self._ignore_tables) - - def get_table_names(self) -> Iterable[str]: - """Get names of tables available.""" - warnings.warn( - "This method is deprecated - please use `get_usable_table_names`." - ) - return self.get_usable_table_names() - - @property - def table_info(self) -> str: - """Information about all tables in the database.""" - return self.get_table_info() - - def get_table_info(self, table_names: Optional[List[str]] = None) -> str: - """Get information about specified tables. 
- - Follows best practices as specified in: Rajkumar et al, 2022 - (https://arxiv.org/abs/2204.00498) - - If `sample_rows_in_table_info`, the specified number of sample rows will be - appended to each table description. This can increase performance as - demonstrated in the paper. - """ - all_table_names = self.get_usable_table_names() - if table_names is not None: - missing_tables = set(table_names).difference(all_table_names) - if missing_tables: - raise ValueError(f"table_names {missing_tables} not found in database") - all_table_names = table_names - - meta_tables = [ - tbl - for tbl in self._metadata.sorted_tables - if tbl.name in set(all_table_names) - and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) - ] - - tables = [] - for table in meta_tables: - if self._custom_table_info and table.name in self._custom_table_info: - tables.append(self._custom_table_info[table.name]) - continue - - # add create table command - create_table = str(CreateTable(table).compile(self._engine)) - table_info = f"{create_table.rstrip()}" - has_extra_info = ( - self._indexes_in_table_info or self._sample_rows_in_table_info - ) - if has_extra_info: - table_info += "\n\n/*" - if self._indexes_in_table_info: - table_info += f"\n{self._get_table_indexes(table)}\n" - if self._sample_rows_in_table_info: - table_info += f"\n{self._get_sample_rows(table)}\n" - if has_extra_info: - table_info += "*/" - tables.append(table_info) - tables.sort() - final_str = "\n\n".join(tables) - return final_str - - def _get_table_indexes(self, table: Table) -> str: - indexes = self._inspector.get_indexes(table.name) - indexes_formatted = "\n".join(map(_format_index, indexes)) - return f"Table Indexes:\n{indexes_formatted}" - - def _get_sample_rows(self, table: Table) -> str: - # build the select command - command = select(table).limit(self._sample_rows_in_table_info) - - # save the columns in string format - columns_str = "\t".join([col.name for col in table.columns]) - - try: - # get 
the sample rows - with self._engine.connect() as connection: - sample_rows_result = connection.execute(command) # type: ignore - # shorten values in the sample rows - sample_rows = list( - map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) - ) - - # save the sample rows in string format - sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) - - # in some dialects when there are no rows in the table a - # 'ProgrammingError' is returned - except ProgrammingError: - sample_rows_str = "" - - return ( - f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" - f"{columns_str}\n" - f"{sample_rows_str}" - ) - - def run(self, command: str, fetch: str = "all") -> str: - """Execute a SQL command and return a string representing the results. - - If the statement returns rows, a string of the results is returned. - If the statement returns no rows, an empty string is returned. - - """ - with self._engine.begin() as connection: - if self._schema is not None: - if self.dialect == "snowflake": - connection.exec_driver_sql( - f"ALTER SESSION SET search_path='{self._schema}'" - ) - elif self.dialect == "bigquery": - connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'") - else: - connection.exec_driver_sql(f"SET search_path TO {self._schema}") - cursor = connection.execute(text(command)) - if cursor.returns_rows: - if fetch == "all": - result = cursor.fetchall() - elif fetch == "one": - result = cursor.fetchone() # type: ignore - else: - raise ValueError("Fetch parameter must be either 'one' or 'all'") - - # Convert columns values to string to avoid issues with sqlalchmey - # trunacating text - if isinstance(result, list): - return str( - [ - tuple( - truncate_word(c, length=self._max_string_length) - for c in r - ) - for r in result - ] - ) - - return str( - tuple( - truncate_word(c, length=self._max_string_length) for c in result - ) - ) - return "" - - def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> 
str: - """Get information about specified tables. - - Follows best practices as specified in: Rajkumar et al, 2022 - (https://arxiv.org/abs/2204.00498) - - If `sample_rows_in_table_info`, the specified number of sample rows will be - appended to each table description. This can increase performance as - demonstrated in the paper. - """ - try: - return self.get_table_info(table_names) - except ValueError as e: - """Format the error message""" - return f"Error: {e}" - - def run_no_throw(self, command: str, fetch: str = "all") -> str: - """Execute a SQL command and return a string representing the results. - - If the statement returns rows, a string of the results is returned. - If the statement returns no rows, an empty string is returned. - - If the statement throws an error, the error message is returned. - """ - try: - return self.run(command, fetch) - except SQLAlchemyError as e: - """Format the error message""" - return f"Error: {e}" +__all__ = ["SQLDatabase"] diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index 5ab9a10ef45..6901ea061c2 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -11,7 +11,7 @@ from langchain.callbacks.manager import ( ) from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate -from langchain.sql_database import SQLDatabase +from langchain.utilities.sql_database import SQLDatabase from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER diff --git a/langchain/utilities/__init__.py b/langchain/utilities/__init__.py index e17b5f54ba8..1564098a101 100644 --- a/langchain/utilities/__init__.py +++ b/langchain/utilities/__init__.py @@ -25,6 +25,7 @@ from langchain.utilities.scenexplain import SceneXplainAPIWrapper from langchain.utilities.searx_search import SearxSearchWrapper from langchain.utilities.serpapi import SerpAPIWrapper from langchain.utilities.spark_sql import SparkSQL 
+from langchain.utilities.sql_database import SQLDatabase from langchain.utilities.twilio import TwilioAPIWrapper from langchain.utilities.wikipedia import WikipediaAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper @@ -56,6 +57,7 @@ __all__ = [ "SearxSearchWrapper", "SerpAPIWrapper", "SparkSQL", + "SQLDatabase", "TextRequestsWrapper", "TwilioAPIWrapper", "WikipediaAPIWrapper", diff --git a/langchain/utilities/sql_database.py b/langchain/utilities/sql_database.py new file mode 100644 index 00000000000..43a43e86fd9 --- /dev/null +++ b/langchain/utilities/sql_database.py @@ -0,0 +1,445 @@ +"""SQLAlchemy wrapper around a database.""" +from __future__ import annotations + +import warnings +from typing import Any, Iterable, List, Optional + +import sqlalchemy +from sqlalchemy import MetaData, Table, create_engine, inspect, select, text +from sqlalchemy.engine import Engine +from sqlalchemy.exc import ProgrammingError, SQLAlchemyError +from sqlalchemy.schema import CreateTable + +from langchain.utils import get_from_env + + +def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: + return ( + f'Name: {index["name"]}, Unique: {index["unique"]},' + f' Columns: {str(index["column_names"])}' + ) + + +def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: + """ + Truncate a string to a certain number of words, based on the max string + length. 
+ """ + + if not isinstance(content, str) or length <= 0: + return content + + if len(content) <= length: + return content + + return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix + + +class SQLDatabase: + """SQLAlchemy wrapper around a database.""" + + def __init__( + self, + engine: Engine, + schema: Optional[str] = None, + metadata: Optional[MetaData] = None, + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, + sample_rows_in_table_info: int = 3, + indexes_in_table_info: bool = False, + custom_table_info: Optional[dict] = None, + view_support: bool = False, + max_string_length: int = 300, + ): + """Create engine from database URI.""" + self._engine = engine + self._schema = schema + if include_tables and ignore_tables: + raise ValueError("Cannot specify both include_tables and ignore_tables") + + self._inspector = inspect(self._engine) + + # including view support by adding the views as well as tables to the all + # tables list if view_support is True + self._all_tables = set( + self._inspector.get_table_names(schema=schema) + + (self._inspector.get_view_names(schema=schema) if view_support else []) + ) + + self._include_tables = set(include_tables) if include_tables else set() + if self._include_tables: + missing_tables = self._include_tables - self._all_tables + if missing_tables: + raise ValueError( + f"include_tables {missing_tables} not found in database" + ) + self._ignore_tables = set(ignore_tables) if ignore_tables else set() + if self._ignore_tables: + missing_tables = self._ignore_tables - self._all_tables + if missing_tables: + raise ValueError( + f"ignore_tables {missing_tables} not found in database" + ) + usable_tables = self.get_usable_table_names() + self._usable_tables = set(usable_tables) if usable_tables else self._all_tables + + if not isinstance(sample_rows_in_table_info, int): + raise TypeError("sample_rows_in_table_info must be an integer") + + self._sample_rows_in_table_info = 
sample_rows_in_table_info + self._indexes_in_table_info = indexes_in_table_info + + self._custom_table_info = custom_table_info + if self._custom_table_info: + if not isinstance(self._custom_table_info, dict): + raise TypeError( + "table_info must be a dictionary with table names as keys and the " + "desired table info as values" + ) + # only keep the tables that are also present in the database + intersection = set(self._custom_table_info).intersection(self._all_tables) + self._custom_table_info = dict( + (table, self._custom_table_info[table]) + for table in self._custom_table_info + if table in intersection + ) + + self._max_string_length = max_string_length + + self._metadata = metadata or MetaData() + # including view support if view_support = true + self._metadata.reflect( + views=view_support, + bind=self._engine, + only=list(self._usable_tables), + schema=self._schema, + ) + + @classmethod + def from_uri( + cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any + ) -> SQLDatabase: + """Construct a SQLAlchemy engine from URI.""" + _engine_args = engine_args or {} + return cls(create_engine(database_uri, **_engine_args), **kwargs) + + @classmethod + def from_databricks( + cls, + catalog: str, + schema: str, + host: Optional[str] = None, + api_token: Optional[str] = None, + warehouse_id: Optional[str] = None, + cluster_id: Optional[str] = None, + engine_args: Optional[dict] = None, + **kwargs: Any, + ) -> SQLDatabase: + """ + Class method to create an SQLDatabase instance from a Databricks connection. + This method requires the 'databricks-sql-connector' package. If not installed, + it can be added using `pip install databricks-sql-connector`. + + Args: + catalog (str): The catalog name in the Databricks database. + schema (str): The schema name in the catalog. + host (Optional[str]): The Databricks workspace hostname, excluding + 'https://' part. If not provided, it attempts to fetch from the + environment variable 'DATABRICKS_HOST'. 
If still unavailable and if + running in a Databricks notebook, it defaults to the current workspace + hostname. Defaults to None. + api_token (Optional[str]): The Databricks personal access token for + accessing the Databricks SQL warehouse or the cluster. If not provided, + it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable + and running in a Databricks notebook, a temporary token for the current + user is generated. Defaults to None. + warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If + provided, the method configures the connection to use this warehouse. + Cannot be used with 'cluster_id'. Defaults to None. + cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If + provided, the method configures the connection to use this cluster. + Cannot be used with 'warehouse_id'. If running in a Databricks notebook + and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the + cluster the notebook is attached to. Defaults to None. + engine_args (Optional[dict]): The arguments to be used when connecting + Databricks. Defaults to None. + **kwargs (Any): Additional keyword arguments for the `from_uri` method. + + Returns: + SQLDatabase: An instance of SQLDatabase configured with the provided + Databricks connection details. + + Raises: + ValueError: If 'databricks-sql-connector' is not found, or if both + 'warehouse_id' and 'cluster_id' are provided, or if neither + 'warehouse_id' nor 'cluster_id' are provided and it's not executing + inside a Databricks notebook. 
+ """ + try: + from databricks import sql # noqa: F401 + except ImportError: + raise ValueError( + "databricks-sql-connector package not found, please install with" + " `pip install databricks-sql-connector`" + ) + context = None + try: + from dbruntime.databricks_repl_context import get_context + + context = get_context() + except ImportError: + pass + + default_host = context.browserHostName if context else None + if host is None: + host = get_from_env("host", "DATABRICKS_HOST", default_host) + + default_api_token = context.apiToken if context else None + if api_token is None: + api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token) + + if warehouse_id is None and cluster_id is None: + if context: + cluster_id = context.clusterId + else: + raise ValueError( + "Need to provide either 'warehouse_id' or 'cluster_id'." + ) + + if warehouse_id and cluster_id: + raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.") + + if warehouse_id: + http_path = f"/sql/1.0/warehouses/{warehouse_id}" + else: + http_path = f"/sql/protocolv1/o/0/{cluster_id}" + + uri = ( + f"databricks://token:{api_token}@{host}?" + f"http_path={http_path}&catalog={catalog}&schema={schema}" + ) + return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs) + + @classmethod + def from_cnosdb( + cls, + url: str = "127.0.0.1:8902", + user: str = "root", + password: str = "", + tenant: str = "cnosdb", + database: str = "public", + ) -> SQLDatabase: + """ + Class method to create an SQLDatabase instance from a CnosDB connection. + This method requires the 'cnos-connector' package. If not installed, it + can be added using `pip install cnos-connector`. + + Args: + url (str): The HTTP connection host name and port number of the CnosDB + service, excluding "http://" or "https://", with a default value + of "127.0.0.1:8902". + user (str): The username used to connect to the CnosDB service, with a + default value of "root". 
@property
def dialect(self) -> str:
    """Return the name of the SQL dialect used by the underlying engine."""
    return self._engine.dialect.name

def get_usable_table_names(self) -> Iterable[str]:
    """Return the sorted names of the tables this instance may use.

    When an include list was configured (``_include_tables`` is non-empty)
    exactly those names are returned; otherwise every known table except the
    ignored ones is returned.
    """
    if self._include_tables:
        return sorted(self._include_tables)
    return sorted(self._all_tables - self._ignore_tables)

def get_table_names(self) -> Iterable[str]:
    """Deprecated alias of `get_usable_table_names`.

    Emits a warning and then returns the result of `get_usable_table_names`.
    """
    warnings.warn(
        "This method is deprecated - please use `get_usable_table_names`.",
        # Attribute the warning to the calling code rather than to this wrapper.
        stacklevel=2,
    )
    return self.get_usable_table_names()

@property
def table_info(self) -> str:
    """Information about all usable tables in the database."""
    return self.get_table_info()

def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
    """Get information about specified tables.

    Follows best practices as specified in: Rajkumar et al, 2022
    (https://arxiv.org/abs/2204.00498)

    If `sample_rows_in_table_info`, the specified number of sample rows will be
    appended to each table description. This can increase performance as
    demonstrated in the paper.

    Args:
        table_names: Tables to describe. ``None`` describes every usable table.

    Returns:
        The per-table descriptions, sorted and separated by a blank line. A
        description is either the custom text registered for that table or its
        compiled ``CREATE TABLE`` statement, optionally followed by a
        ``/* ... */`` block holding the index and/or sample-row sections.

    Raises:
        ValueError: If any name in ``table_names`` is not a usable table.
    """
    usable_names = self.get_usable_table_names()
    if table_names is not None:
        missing_tables = set(table_names).difference(usable_names)
        if missing_tables:
            raise ValueError(f"table_names {missing_tables} not found in database")
        usable_names = table_names

    # Build the lookup set once instead of once per reflected table.
    wanted = set(usable_names)
    meta_tables = [
        tbl
        for tbl in self._metadata.sorted_tables
        if tbl.name in wanted
        # tables named "sqlite_*" are never described on the sqlite dialect
        and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
    ]

    tables = []
    for table in meta_tables:
        # A user-supplied description replaces the generated one entirely.
        if self._custom_table_info and table.name in self._custom_table_info:
            tables.append(self._custom_table_info[table.name])
            continue

        # add create table command
        create_table = str(CreateTable(table).compile(self._engine))
        table_info = create_table.rstrip()
        has_extra_info = (
            self._indexes_in_table_info or self._sample_rows_in_table_info
        )
        # Extra sections are wrapped in a single SQL block comment.
        if has_extra_info:
            table_info += "\n\n/*"
        if self._indexes_in_table_info:
            table_info += f"\n{self._get_table_indexes(table)}\n"
        if self._sample_rows_in_table_info:
            table_info += f"\n{self._get_sample_rows(table)}\n"
        if has_extra_info:
            table_info += "*/"
        tables.append(table_info)
    tables.sort()
    return "\n\n".join(tables)

def _get_table_indexes(self, table: Table) -> str:
    """Return a "Table Indexes:" section for ``table``.

    Each index reported by the inspector is rendered with `_format_index`,
    one per line.
    """
    indexes = self._inspector.get_indexes(table.name)
    indexes_formatted = "\n".join(map(_format_index, indexes))
    return f"Table Indexes:\n{indexes_formatted}"

def _get_sample_rows(self, table: Table) -> str:
    """Return a tab-separated preview of the first rows of ``table``.

    The section consists of a header line, the column names, and up to
    ``_sample_rows_in_table_info`` rows. Every value is converted with
    ``str`` and cut to 100 characters. The header always states the
    configured row count, even when fewer rows (or none) are listed.
    """
    # build the select command
    command = select(table).limit(self._sample_rows_in_table_info)

    # save the columns in string format
    columns_str = "\t".join([col.name for col in table.columns])

    try:
        # get the sample rows
        with self._engine.connect() as connection:
            sample_rows_result = connection.execute(command)  # type: ignore
            # shorten values in the sample rows
            sample_rows = [
                [str(value)[:100] for value in row] for row in sample_rows_result
            ]

        # save the sample rows in string format
        sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

    # in some dialects when there are no rows in the table a
    # 'ProgrammingError' is returned
    except ProgrammingError:
        sample_rows_str = ""

    return (
        f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
        f"{columns_str}\n"
        f"{sample_rows_str}"
    )

def run(self, command: str, fetch: str = "all") -> str:
    """Execute a SQL command and return a string representing the results.

    If the statement returns rows, a string of the results is returned.
    If the statement returns no rows, an empty string is returned.

    Args:
        command: The SQL text to execute.
        fetch: ``"all"`` to return every row, ``"one"`` to return only the
            first row.

    Returns:
        ``str()`` of a list of tuples for ``"all"``, ``str()`` of a single
        tuple for ``"one"``, or ``""`` when the statement produces no rows.

    Raises:
        ValueError: If the statement returns rows and ``fetch`` is neither
            ``"all"`` nor ``"one"``.
    """
    with self._engine.begin() as connection:
        if self._schema is not None:
            # NOTE(review): `self._schema` is interpolated directly into the
            # statement text; it must come from trusted configuration.
            if self.dialect == "snowflake":
                connection.exec_driver_sql(
                    f"ALTER SESSION SET search_path='{self._schema}'"
                )
            elif self.dialect == "bigquery":
                connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
            else:
                connection.exec_driver_sql(f"SET search_path TO {self._schema}")
        cursor = connection.execute(text(command))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
            elif fetch == "one":
                result = cursor.fetchone()  # type: ignore
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")

            # `fetchone()` yields None when the query matched nothing;
            # honour the documented "no rows -> empty string" contract
            # instead of failing while iterating over None.
            if result is None:
                return ""

            # Convert column values to strings to avoid issues with
            # SQLAlchemy truncating text
            if isinstance(result, list):
                return str(
                    [
                        tuple(
                            truncate_word(c, length=self._max_string_length)
                            for c in r
                        )
                        for r in result
                    ]
                )

            return str(
                tuple(
                    truncate_word(c, length=self._max_string_length) for c in result
                )
            )
    return ""

def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
    """Get information about specified tables, reporting lookup errors as text.

    Follows best practices as specified in: Rajkumar et al, 2022
    (https://arxiv.org/abs/2204.00498)

    If `sample_rows_in_table_info`, the specified number of sample rows will be
    appended to each table description. This can increase performance as
    demonstrated in the paper.

    Behaves like `get_table_info`, except that a ``ValueError`` raised by it
    is caught and returned as the string ``"Error: <message>"``.
    """
    try:
        return self.get_table_info(table_names)
    except ValueError as e:
        # Format the error message
        return f"Error: {e}"

def run_no_throw(self, command: str, fetch: str = "all") -> str:
    """Execute a SQL command and return a string representing the results.

    If the statement returns rows, a string of the results is returned.
    If the statement returns no rows, an empty string is returned.

    If the statement throws an error, the error message is returned.
    Only ``SQLAlchemyError`` is caught; any other exception raised by `run`
    still propagates.
    """
    try:
        return self.run(command, fetch)
    except SQLAlchemyError as e:
        # Format the error message
        return f"Error: {e}"
langchain.sql_database import SQLDatabase, truncate_word +from langchain.utilities.sql_database import SQLDatabase, truncate_word metadata_obj = MetaData() diff --git a/tests/unit_tests/test_sql_database_schema.py b/tests/unit_tests/test_sql_database_schema.py index a2661b4dcb1..b2a6589463d 100644 --- a/tests/unit_tests/test_sql_database_schema.py +++ b/tests/unit_tests/test_sql_database_schema.py @@ -18,7 +18,7 @@ from sqlalchemy import ( schema, ) -from langchain.sql_database import SQLDatabase +from langchain.utilities.sql_database import SQLDatabase metadata_obj = MetaData()