mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-25 21:03:11 +00:00 
			
		
		
		
	Previously, if this did not find a mypy cache then it wouldn't run; this makes it always run. Adds mypy ignore comments for existing uncaught issues to unblock other PRs. --------- Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
		
			
				
	
	
		
			494 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			494 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """SQLAlchemy wrapper around a database."""
 | |
| from __future__ import annotations
 | |
| 
 | |
| from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
 | |
| 
 | |
| import sqlalchemy
 | |
| from langchain_core._api import deprecated
 | |
| from langchain_core.utils import get_from_env
 | |
| from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
 | |
| from sqlalchemy.engine import Engine
 | |
| from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
 | |
| from sqlalchemy.schema import CreateTable
 | |
| from sqlalchemy.types import NullType
 | |
| 
 | |
| 
 | |
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
    """Render one reflected index as a single descriptive line.

    Reads the ``name``, ``unique`` and ``column_names`` entries of ``index``
    and returns them formatted as ``Name: ..., Unique: ..., Columns: ...``.
    """
    name = index["name"]
    is_unique = index["unique"]
    columns = str(index["column_names"])
    return f"Name: {name}, Unique: {is_unique}, Columns: {columns}"
 | |
| 
 | |
| 
 | |
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Shorten a string to at most ``length`` characters on a word boundary.

    Args:
        content: Value to shorten. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum length of the result, suffix included. A value of
            zero or less disables truncation.
        suffix: Marker appended when the text is cut.

    Returns:
        ``content`` itself when it is not a string, when ``length <= 0``, or
        when it already fits. Otherwise the text is cut to
        ``length - len(suffix)`` characters, everything after the last space
        in that prefix is dropped, and ``suffix`` is appended. If ``length``
        is smaller than ``len(suffix)`` the text is hard-cut to ``length``
        characters with no suffix.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    budget = length - len(suffix)
    if budget < 0:
        # A negative slice end counts from the END of the string, which would
        # keep almost the whole text and overshoot ``length``. The suffix
        # cannot fit either, so fall back to a plain cut.
        return content[:length]

    return content[:budget].rsplit(" ", 1)[0] + suffix
 | |
| 
 | |
| 
 | |
class SQLDatabase:
    """SQLAlchemy wrapper around a database."""

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Wrap an existing SQLAlchemy engine and reflect its usable tables.

        Args:
            engine: Engine used for inspection, reflection and query execution.
            schema: Schema passed to the inspector and to metadata reflection.
                ``_execute`` also uses it to select the schema for some
                dialects.
            metadata: ``MetaData`` object to reflect into; a new one is created
                when omitted.
            ignore_tables: Table names to exclude. Cannot be combined with
                ``include_tables``.
            include_tables: Table names to restrict the wrapper to. Cannot be
                combined with ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description; ``0`` disables them.
            indexes_in_table_info: Whether table descriptions list indexes.
            custom_table_info: Mapping of table name to a description that
                replaces the generated one. Entries for tables that are not
                in the database are dropped.
            view_support: Whether views are treated as tables.
            max_string_length: Limit handed to ``truncate_word`` for every
                value returned by ``run``.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table missing from the database.
            TypeError: If ``sample_rows_in_table_info`` is not an ``int`` or a
                non-empty ``custom_table_info`` is not a ``dict``.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        # An empty usable list falls back to every known table.
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # including view support if view_support = true
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SQLDatabase:
        """Construct a SQLAlchemy engine from URI.

        Args:
            database_uri: URI handed to ``sqlalchemy.create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the ``SQLDatabase`` constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not installed,
        it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                'https://' part. If not provided, it attempts to fetch from the
                environment variable 'DATABRICKS_HOST'. If still unavailable and if
                running in a Databricks notebook, it defaults to the current workspace
                hostname. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token for
                accessing the Databricks SQL warehouse or the cluster. If not provided,
                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
                and running in a Databricks notebook, a temporary token for the current
                user is generated. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
                provided, the method configures the connection to use this warehouse.
                Cannot be used with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
                provided, the method configures the connection to use this cluster.
                Cannot be used with 'warehouse_id'. If running in a Databricks notebook
                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
                cluster the notebook is attached to. Defaults to None.
            engine_args (Optional[dict]): The arguments to be used when connecting
                Databricks. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ValueError: If 'databricks-sql-connector' is not found, or if both
                'warehouse_id' and 'cluster_id' are provided, or if neither
                'warehouse_id' nor 'cluster_id' are provided and it's not executing
                inside a Databricks notebook.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as exc:
            raise ValueError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from exc

        # The REPL context module is importable only inside a Databricks
        # notebook; without it there are no notebook-derived defaults.
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a CnosDB connection.
        This method requires the 'cnos-connector' package. If not installed, it
        can be added using `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the CnosDB
                service, excluding "http://" or "https://", with a default value
                of "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service, with a
                default value of "root".
            password (str): The password of the user connecting to the CnosDB service,
                with a default value of "".
            tenant (str): The name of the tenant used to connect to the CnosDB service,
                with a default value of "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
            CnosDB connection details.

        Raises:
            ValueError: If the 'cnos-connector' package cannot be imported.
        """
        # Guard only the import: an ImportError raised later while building
        # the engine must not be reported as a missing cnos-connector.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as exc:
            raise ValueError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from exc

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the sorted ``include_tables`` when any were given, otherwise
        every known table minus ``ignore_tables``, also sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @deprecated("0.0.1", alternative="get_usable_table_names", removal="0.2.0")
    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe; all usable tables when omitted.

        Returns:
            The per-table descriptions, sorted and joined by blank lines.

        Raises:
            ValueError: If ``table_names`` contains a table that is not usable.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Build the lookup set once instead of once per reflected table.
        wanted_names = set(all_table_names)
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted_names
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Ignore JSON datatyped columns
            # Iterate over a snapshot because columns are removed in the loop.
            for column in list(table.columns):
                if type(column.type) is NullType:
                    table._columns.remove(column)

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            # Indexes and sample rows are wrapped in one SQL block comment.
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        final_str = "\n\n".join(tables)
        return final_str

    def _get_table_indexes(self, table: Table) -> str:
        """Return a "Table Indexes:" section listing the indexes of ``table``."""
        # Use the same schema as the table-name inspection in ``__init__``.
        indexes = self._inspector.get_indexes(table.name, schema=self._schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``.

        The output is a heading, a header line of column names, and up to
        ``sample_rows_in_table_info`` rows with each value stringified and
        cut to 100 characters.
        """
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # shorten values in the sample rows
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
    ) -> Sequence[Dict[str, Any]]:
        """
        Executes SQL command through underlying engine.

        If the statement returns no rows, an empty list is returned.

        Args:
            command: SQL text to execute.
            fetch: ``"all"`` to return every row, ``"one"`` to return at most
                the first row.

        Returns:
            One dict per fetched row, mapping column name to value.

        Raises:
            ValueError: If the statement returns rows and ``fetch`` is neither
                ``"all"`` nor ``"one"``.
        """
        with self._engine.begin() as connection:  # type: Connection  # type: ignore[name-defined]
            if self._schema is not None:
                # Select the configured schema with the dialect's own
                # statement before running the command. Dialects not listed
                # here get no schema statement.
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s", (self._schema,)
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
                elif self.dialect == "mssql":
                    pass
                elif self.dialect == "trino":
                    connection.exec_driver_sql("USE ?", (self._schema,))
                elif self.dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    # NOTE(review): the schema is interpolated into the SQL
                    # text here, so it must come from trusted configuration.
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")
                elif self.dialect == "oracle":
                    # NOTE(review): the schema is interpolated into the SQL
                    # text here, so it must come from trusted configuration.
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
                    )
                elif self.dialect == "sqlany":
                    # If anybody using Sybase SQL anywhere database then it should not
                    # go to else condition. It should be same as mssql.
                    pass
                elif self.dialect == "postgresql":  # postgresql
                    connection.exec_driver_sql("SET search_path TO %s", (self._schema,))

            cursor = connection.execute(text(command))
            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                else:
                    raise ValueError("Fetch parameter must be either 'one' or 'all'")
                return result
        return []

    def run(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
    ) -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        Args:
            command: SQL text to execute.
            fetch: ``"all"`` for every row, ``"one"`` for at most one row.
            include_columns: When true each row is rendered as a dict keyed
                by column name; otherwise as a tuple of values.
        """
        result = self._execute(command, fetch)

        # Every value goes through truncate_word with max_string_length.
        res = [
            {
                column: truncate_word(value, length=self._max_string_length)
                for column, value in r.items()
            }
            for r in result
        ]

        if not include_columns:
            res = [tuple(row.values()) for row in res]  # type: ignore[misc]

        if not res:
            return ""
        else:
            return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` from ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of being raised.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
    ) -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch, include_columns)
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"

    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt.

        The dict has ``table_info`` (from ``get_table_info_no_throw``) and
        ``table_names`` (the usable table names joined by ``", "``).
        """
        table_names = list(self.get_usable_table_names())
        table_info = self.get_table_info_no_throw()
        return {"table_info": table_info, "table_names": ", ".join(table_names)}
 |