langchain/libs/community/langchain_community/utilities/sql_database.py
Kyungho Byoun e6b6c07395
community: add HANA dialect to SQLDatabase (#30475)
This PR adds support for the HANA dialect to SQLDatabase, which is a
wrapper class for SQLAlchemy.

Currently, the schema name cannot be set when using HANA DB with
LangChain. No message is shown to the user either, which makes it
hard to figure out why the SQL does not work as expected.

Here is the HANA DB reference documentation for setting the schema of
a session:

- [SET SCHEMA Statement (Session
Management)](https://help.sap.com/docs/SAP_HANA_PLATFORM/4fe29514fd584807ac9f2a04f6754767/20fd550375191014b886a338afb4cd5f.html)
2025-03-27 15:19:50 -04:00

598 lines
23 KiB
Python

"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import URL, Engine, Result
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.sql.expression import Executable
from sqlalchemy.types import NullType
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
    """Render one reflected index as a single human-readable line.

    Args:
        index: Mapping describing an index; the ``name``, ``unique`` and
            ``column_names`` keys are read.

    Returns:
        A string of the form ``Name: ..., Unique: ..., Columns: ...``.
    """
    name = index["name"]
    unique = index["unique"]
    columns = str(index["column_names"])
    return f"Name: {name}, Unique: {unique}, Columns: {columns}"
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string at a word boundary so it fits within ``length``.

    Args:
        content: Value to truncate. Anything that is not a ``str`` is
            returned unchanged.
        length: Maximum length of the returned string, suffix included.
            A value ``<= 0`` disables truncation.
        suffix: Marker appended to a truncated string.

    Returns:
        ``content`` unchanged when it is not a string, when truncation is
        disabled, or when it already fits; otherwise the shortened text
        followed by ``suffix``.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    # When the suffix alone fills (or overflows) the budget, a negative slice
    # bound would trim from the wrong end and return more than ``length``
    # characters. Emit as much of the suffix as fits instead.
    if length <= len(suffix):
        return suffix[:length]

    # Cut to the remaining budget, then drop the trailing partial word.
    return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix
def sanitize_schema(schema: str) -> str:
    """Validate that a schema name contains only letters, digits, and underscores.

    The name is returned unchanged when valid, so the call can be used inline
    where the schema is interpolated into a SQL statement.

    Args:
        schema: Candidate schema name.

    Returns:
        The same ``schema`` string.

    Raises:
        ValueError: If ``schema`` is empty or contains any character other
            than ASCII letters, digits, or underscores.
    """
    # ``fullmatch`` is required here: with ``match`` and a ``$`` anchor, a name
    # ending in a newline would be accepted because ``$`` also matches just
    # before a trailing newline.
    if not re.fullmatch(r"[a-zA-Z0-9_]+", schema):
        raise ValueError(
            f"Schema name '{schema}' contains invalid characters. "
            "Schema names must contain only letters, digits, and underscores."
        )
    return schema
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Wraps an ``Engine`` and exposes helpers to list usable tables, describe
    them (``CREATE TABLE`` statement, optional indexes and sample rows) and
    run SQL, returning results as strings.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
        lazy_table_reflection: bool = False,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: Engine used for inspection, reflection and execution.
            schema: Schema used when listing and reflecting tables, and set
                on the session before each statement for supported dialects.
            metadata: ``MetaData`` to reflect into; a new one is created when
                omitted.
            ignore_tables: Tables to exclude. Cannot be combined with
                ``include_tables``.
            include_tables: Tables to restrict to. Cannot be combined with
                ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description (0 disables them).
            indexes_in_table_info: Whether to append index information to
                each table description.
            custom_table_info: Mapping of table name to a pre-written
                description used instead of the generated one. Entries for
                tables not present in the database are dropped.
            view_support: Whether views are treated as tables.
            max_string_length: Maximum length of each value returned by
                ``run`` before truncation.
            lazy_table_reflection: When true, tables are not reflected here;
                ``get_table_info`` reflects them on demand.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table missing from the database.
            TypeError: If ``sample_rows_in_table_info`` is not an integer or
                ``custom_table_info`` is not a dictionary.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
            list(self._inspector.get_table_names(schema=schema))
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            intersection = set(self._custom_table_info).intersection(self._all_tables)
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in intersection
            }

        self._max_string_length = max_string_length
        self._view_support = view_support

        self._metadata = metadata or MetaData()
        if not lazy_table_reflection:
            # including view support if view_support = true
            self._metadata.reflect(
                views=view_support,
                bind=self._engine,
                only=list(self._usable_tables),
                schema=self._schema,
            )

    @classmethod
    def from_uri(
        cls,
        database_uri: Union[str, URL],
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """Construct a SQLAlchemy engine from URI.

        Args:
            database_uri: Connection string or ``URL`` passed to
                ``create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A new instance wrapping the created engine.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    @deprecated(
        "0.3.18",
        message="For performing structured retrieval using Databricks SQL, "
        "see the latest best practices and recommended APIs at "
        "https://docs.unitycatalog.io/ai/integrations/langchain/ "  # noqa: E501
        "instead",
        removal="1.0",
    )
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not installed,
        it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                'https://' part. If not provided, it attempts to fetch from the
                environment variable 'DATABRICKS_HOST'. If still unavailable and if
                running in a Databricks notebook, it defaults to the current workspace
                hostname. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token for
                accessing the Databricks SQL warehouse or the cluster. If not provided,
                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
                and running in a Databricks notebook, a temporary token for the current
                user is generated. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
                provided, the method configures the connection to use this warehouse.
                Cannot be used with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
                provided, the method configures the connection to use this cluster.
                Cannot be used with 'warehouse_id'. If running in a Databricks notebook
                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
                cluster the notebook is attached to. Defaults to None.
            engine_args (Optional[dict]): The arguments to be used when connecting
                Databricks. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ImportError: If 'databricks-sql-connector' is not found.
            ValueError: If both 'warehouse_id' and 'cluster_id' are provided, or
                if neither 'warehouse_id' nor 'cluster_id' are provided and it's
                not executing inside a Databricks notebook.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from e

        # The notebook context is optional: outside a Databricks notebook the
        # import (or the attribute lookup) fails and no defaults are available.
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
            default_host = context.browserHostName
        except (ImportError, AttributeError):
            default_host = None

        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a CnosDB connection.
        This method requires the 'cnos-connector' package. If not installed, it
        can be added using `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the CnosDB
                service, excluding "http://" or "https://", with a default value
                of "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service, with a
                default value of "root".
            password (str): The password of the user connecting to the CnosDB service,
                with a default value of "".
            tenant (str): The name of the tenant used to connect to the CnosDB service,
                with a default value of "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
            CnosDB connection details.

        Raises:
            ImportError: If the 'cnos-connector' package is not installed.
        """
        # Only the import is guarded, so an ImportError raised while building
        # the engine is not misreported as a missing cnos-connector package.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as e:
            raise ImportError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from e

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns:
            The sorted ``include_tables`` when given, otherwise every known
            table minus ``ignore_tables``, sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @deprecated("0.0.1", alternative="get_usable_table_names", removal="1.0")
    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe; all usable tables when omitted.

        Returns:
            The table descriptions, sorted and separated by blank lines.

        Raises:
            ValueError: If any requested table is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Reflect any requested table that is not in the metadata yet (this is
        # what makes lazy_table_reflection work).
        requested = set(all_table_names)
        metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]
        to_reflect = requested - set(metadata_table_names)
        if to_reflect:
            self._metadata.reflect(
                views=self._view_support,
                bind=self._engine,
                only=list(to_reflect),
                schema=self._schema,
            )

        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in requested
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Ignore JSON datatyped columns
            for _, v in table.columns.items():  # AttributeError: items in sqlalchemy v1
                if type(v.type) is NullType:
                    table._columns.remove(v)

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = create_table.rstrip()
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            # Extra information is wrapped in a SQL block comment.
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        final_str = "\n\n".join(tables)
        return final_str

    def _get_table_indexes(self, table: Table) -> str:
        """Return a formatted listing of the indexes defined on ``table``."""
        # Look the table up in its own schema; without it the inspector
        # searches the default schema only.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``."""
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # shorten values in the sample rows
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Result]:
        """
        Executes SQL command through underlying engine.

        If the statement returns no rows, an empty list is returned.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: ``"all"`` returns every row, ``"one"`` at most the first
                row, ``"cursor"`` the ``Result`` object itself.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to every statement executed
                here, including the schema-selection statement.

        Returns:
            A list of row dictionaries, or the ``Result`` when
            ``fetch == "cursor"``.

        Raises:
            TypeError: If ``command`` is neither a string nor an executable.
            ValueError: If the statement returns rows and ``fetch`` is not
                one of the supported values.
        """
        parameters = parameters or {}
        execution_options = execution_options or {}
        with self._engine.begin() as connection:  # type: Connection # type: ignore[name-defined]
            # Select the configured schema for this session; the statement is
            # dialect specific, and dialects not listed here are left as-is.
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(
                        "SET @@dataset_id=?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "mssql":
                    pass
                elif self.dialect == "trino":
                    connection.exec_driver_sql(
                        "USE ?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    # NOTE(review): the schema is interpolated unescaped here;
                    # confirm it never comes from untrusted input.
                    connection.exec_driver_sql(
                        f"SET search_path TO {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "oracle":
                    # NOTE(review): the schema is interpolated unescaped here;
                    # confirm it never comes from untrusted input.
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "sqlany":
                    # If anybody using Sybase SQL anywhere database then it should not
                    # go to else condition. It should be same as mssql.
                    pass
                elif self.dialect == "postgresql":  # postgresql
                    connection.exec_driver_sql(
                        "SET search_path TO %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "hana":
                    # The schema name is validated before being interpolated.
                    connection.exec_driver_sql(
                        f"SET SCHEMA {sanitize_schema(self._schema)}",
                        execution_options=execution_options,
                    )

            if isinstance(command, str):
                command = text(command)
            elif isinstance(command, Executable):
                pass
            else:
                raise TypeError(f"Query expression has unknown type: {type(command)}")
            cursor = connection.execute(
                command,
                parameters,
                execution_options=execution_options,
            )

            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                elif fetch == "cursor":
                    # NOTE(review): the Result is returned from inside the
                    # `begin()` block, so the block exits on return; confirm
                    # callers can still consume it afterwards.
                    return cursor
                else:
                    raise ValueError(
                        "Fetch parameter must be either 'one', 'all', or 'cursor'"
                    )
                return result
        return []

    def run(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: ``"all"``, ``"one"`` or ``"cursor"``; with ``"cursor"``
                the raw result of ``_execute`` is returned untouched.
            include_columns: When true each row is rendered as a dictionary
                keyed by column name, otherwise as a tuple of values.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to the execution.

        Returns:
            The stringified rows, with string values truncated to
            ``max_string_length``; ``""`` when there are no rows.
        """
        result = self._execute(
            command, fetch, parameters=parameters, execution_options=execution_options
        )

        if fetch == "cursor":
            return result

        res = [
            {
                column: truncate_word(value, length=self._max_string_length)
                for column, value in r.items()
            }
            for r in result
        ]

        if not include_columns:
            res = [tuple(row.values()) for row in res]  # type: ignore[misc]

        if not res:
            return ""
        else:
            return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` raised by ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of propagating.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        Only ``SQLAlchemyError`` is converted this way; other exceptions
        propagate.
        """
        try:
            return self.run(
                command,
                fetch,
                parameters=parameters,
                execution_options=execution_options,
                include_columns=include_columns,
            )
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"

    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt.

        Returns:
            A dictionary with ``table_info`` (the table descriptions, or an
            error string) and ``table_names`` (usable table names joined
            with ``", "``).
        """
        table_names = list(self.get_usable_table_names())
        table_info = self.get_table_info_no_throw()
        return {"table_info": table_info, "table_names": ", ".join(table_names)}