mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 12:25:37 +00:00
This PR adds support for the HANA dialect to SQLDatabase, which is a wrapper class around SQLAlchemy. Currently, the schema name cannot be set when using HANA DB with LangChain, and no message is shown to the user, which makes it hard to figure out why the SQL does not work as expected. Here is the reference documentation for setting the schema of a session in HANA DB: - [SET SCHEMA Statement (Session Management)](https://help.sap.com/docs/SAP_HANA_PLATFORM/4fe29514fd584807ac9f2a04f6754767/20fd550375191014b886a338afb4cd5f.html)
598 lines
23 KiB
Python
598 lines
23 KiB
Python
"""SQLAlchemy wrapper around a database."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
|
|
|
|
import sqlalchemy
|
|
from langchain_core._api import deprecated
|
|
from langchain_core.utils import get_from_env
|
|
from sqlalchemy import (
|
|
MetaData,
|
|
Table,
|
|
create_engine,
|
|
inspect,
|
|
select,
|
|
text,
|
|
)
|
|
from sqlalchemy.engine import URL, Engine, Result
|
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
|
from sqlalchemy.schema import CreateTable
|
|
from sqlalchemy.sql.expression import Executable
|
|
from sqlalchemy.types import NullType
|
|
|
|
|
|
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
|
return (
|
|
f"Name: {index['name']}, Unique: {index['unique']},"
|
|
f" Columns: {str(index['column_names'])}"
|
|
)
|
|
|
|
|
|
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string to a maximum character length on a word boundary.

    The limit is measured in characters, not words: the text is cut so that,
    together with ``suffix``, it fits in ``length`` characters, and the cut is
    then moved back to the last space so that no word is split.

    Args:
        content: Value to truncate. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum length of the result in characters. Values ``<= 0``
            disable truncation.
        suffix: Marker appended to truncated text.

    Returns:
        ``content`` unchanged when it is not a string, when ``length <= 0`` or
        when it already fits; otherwise the shortened text followed by
        ``suffix``. If ``length`` is smaller than ``len(suffix)`` the result is
        just ``suffix``.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    # Clamp at zero: with ``length < len(suffix)`` the bound would be negative
    # and the slice would keep almost the whole string instead of shortening it.
    keep = max(length - len(suffix), 0)
    return content[:keep].rsplit(" ", 1)[0] + suffix
|
|
|
|
|
|
def sanitize_schema(schema: str) -> str:
    """Validate that a schema name contains only letters, digits, and underscores.

    Args:
        schema: Schema name to validate.

    Returns:
        The schema name, unchanged, when it is valid.

    Raises:
        ValueError: If the name is empty or contains any other character.
    """
    # ``fullmatch`` rather than ``match`` with ``^...$``: ``$`` also matches
    # just before a trailing newline, which would let "name\n" through.
    if not re.fullmatch(r"[a-zA-Z0-9_]+", schema):
        raise ValueError(
            f"Schema name '{schema}' contains invalid characters. "
            "Schema names must contain only letters, digits, and underscores."
        )
    return schema
|
|
|
|
|
|
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Holds an engine together with the set of tables (and optionally views)
    that may be exposed, and offers helpers to describe those tables and to
    run SQL against the engine.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
        lazy_table_reflection: bool = False,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: Engine used for inspection, reflection and execution.
            schema: Schema passed to the inspector and to reflection, and set
                on the session before each statement for supported dialects.
            metadata: Metadata object to reflect into; a new one is created
                when omitted.
            ignore_tables: Tables to hide. Mutually exclusive with
                ``include_tables``.
            include_tables: Tables to expose. Mutually exclusive with
                ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description (falsy disables them).
            indexes_in_table_info: Whether to append index information to each
                table description.
            custom_table_info: Mapping of table name to a description that
                replaces the generated one. Entries for tables that are not in
                the database are dropped.
            view_support: Whether views are listed and reflected like tables.
            max_string_length: Maximum length of string values returned by
                ``run``.
            lazy_table_reflection: When true, reflection is deferred until
                ``get_table_info`` is called.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table not found in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an integer, or
                if a non-empty ``custom_table_info`` is not a dictionary.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
            list(self._inspector.get_table_names(schema=schema))
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        # Fall back to every known table when the usable list is empty.
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length
        self._view_support = view_support

        self._metadata = metadata or MetaData()
        if not lazy_table_reflection:
            # including view support if view_support = true
            self._metadata.reflect(
                views=view_support,
                bind=self._engine,
                only=list(self._usable_tables),
                schema=self._schema,
            )

    @classmethod
    def from_uri(
        cls,
        database_uri: Union[str, URL],
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """Construct a SQLAlchemy engine from URI.

        Args:
            database_uri: Database URL handed to ``create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Extra keyword arguments for the class constructor.

        Returns:
            A new instance wrapping the created engine.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    @deprecated(
        "0.3.18",
        message="For performing structured retrieval using Databricks SQL, "
        "see the latest best practices and recommended APIs at "
        "https://docs.unitycatalog.io/ai/integrations/langchain/ "  # noqa: E501
        "instead",
        removal="1.0",
    )
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not installed,
        it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                'https://' part. If not provided, it attempts to fetch from the
                environment variable 'DATABRICKS_HOST'. If still unavailable and if
                running in a Databricks notebook, it defaults to the current workspace
                hostname. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token for
                accessing the Databricks SQL warehouse or the cluster. If not provided,
                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
                and running in a Databricks notebook, a temporary token for the current
                user is generated. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
                provided, the method configures the connection to use this warehouse.
                Cannot be used with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
                provided, the method configures the connection to use this cluster.
                Cannot be used with 'warehouse_id'. If running in a Databricks notebook
                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
                cluster the notebook is attached to. Defaults to None.
            engine_args (Optional[dict]): The arguments to be used when connecting
                Databricks. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ImportError: If 'databricks-sql-connector' is not found.
            ValueError: If both 'warehouse_id' and 'cluster_id' are provided, or
                if neither 'warehouse_id' nor 'cluster_id' are provided and it's
                not executing inside a Databricks notebook.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from e
        context = None
        try:
            # Only importable inside a Databricks runtime; elsewhere the
            # notebook-derived defaults are simply unavailable.
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
            default_host = context.browserHostName
        except (ImportError, AttributeError):
            default_host = None

        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a CnosDB connection.
        This method requires the 'cnos-connector' package. If not installed, it
        can be added using `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the CnosDB
                service, excluding "http://" or "https://", with a default value
                of "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service, with a
                default value of "root".
            password (str): The password of the user connecting to the CnosDB service,
                with a default value of "".
            tenant (str): The name of the tenant used to connect to the CnosDB service,
                with a default value of "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
            CnosDB connection details.

        Raises:
            ImportError: If the 'cnos-connector' package is not installed.
        """
        # Keep the try body to the import alone so that an ImportError raised
        # while building the engine is not misreported as a missing connector.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as e:
            raise ImportError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from e

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns:
            The sorted ``include_tables`` when any were given, otherwise every
            known table minus ``ignore_tables``, sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @deprecated("0.0.1", alternative="get_usable_table_names", removal="1.0")
    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe; all usable tables when omitted.

        Returns:
            The table descriptions, sorted and separated by blank lines.

        Raises:
            ValueError: If a requested table is not among the usable tables.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Built once and reused for both the reflection diff and the filter.
        requested = set(all_table_names)

        # Reflect whatever was requested but is not in the metadata yet
        # (everything, when reflection was deferred at construction).
        already_reflected = {tbl.name for tbl in self._metadata.sorted_tables}
        to_reflect = requested - already_reflected
        if to_reflect:
            self._metadata.reflect(
                views=self._view_support,
                bind=self._engine,
                only=list(to_reflect),
                schema=self._schema,
            )

        is_sqlite = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in requested
            and not (is_sqlite and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Ignore JSON datatyped columns: any column whose reflected type is
            # exactly NullType is removed from the table before rendering.
            for _, column in table.columns.items():  # AttributeError: items in sqlalchemy v1
                if type(column.type) is NullType:
                    table._columns.remove(column)

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            # Indexes and sample rows are wrapped in one SQL block comment.
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        final_str = "\n\n".join(tables)
        return final_str

    def _get_table_indexes(self, table: Table) -> str:
        """Describe the indexes of ``table``, one per line.

        The table's own schema is passed to the inspector so that tables
        reflected from a non-default schema are looked up in that schema.
        """
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``.

        The output is a heading line, the column names, and up to
        ``sample_rows_in_table_info`` rows with each value cut to 100
        characters. A ``ProgrammingError`` yields an empty row section.
        """
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # shorten values in the sample rows; rows are consumed while
                # the connection is still open
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Result]:
        """
        Executes SQL command through underlying engine.

        If the statement returns no rows, an empty list is returned.

        When a schema is configured, a dialect-specific statement selecting
        that schema is issued first on the same connection. Dialects without a
        branch below get no such statement.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: "all" for every row, "one" for at most one row, "cursor"
                for the raw result object.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to each execution call.

        Returns:
            A list of row dictionaries, or the result object for "cursor".

        Raises:
            TypeError: If ``command`` is neither a string nor an executable.
            ValueError: If rows are returned and ``fetch`` is not a supported
                value, or if the hana schema name fails validation.
        """
        parameters = parameters or {}
        execution_options = execution_options or {}
        with self._engine.begin() as connection:  # type: Connection # type: ignore[name-defined]
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(
                        "SET @@dataset_id=?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "mssql":
                    pass
                elif self.dialect == "trino":
                    connection.exec_driver_sql(
                        "USE ?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    # NOTE(review): the schema is interpolated unsanitized here,
                    # unlike the hana branch below.
                    connection.exec_driver_sql(
                        f"SET search_path TO {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "oracle":
                    # NOTE(review): the schema is interpolated unsanitized here,
                    # unlike the hana branch below.
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "sqlany":
                    # If anybody using Sybase SQL anywhere database then it should not
                    # go to else condition. It should be same as mssql.
                    pass
                elif self.dialect == "postgresql":  # postgresql
                    connection.exec_driver_sql(
                        "SET search_path TO %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "hana":
                    # The schema name is validated before being interpolated
                    # into the statement.
                    connection.exec_driver_sql(
                        f"SET SCHEMA {sanitize_schema(self._schema)}",
                        execution_options=execution_options,
                    )

            if isinstance(command, str):
                command = text(command)
            elif isinstance(command, Executable):
                pass
            else:
                raise TypeError(f"Query expression has unknown type: {type(command)}")
            cursor = connection.execute(
                command,
                parameters,
                execution_options=execution_options,
            )

            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                elif fetch == "cursor":
                    # NOTE(review): returned from inside the ``with`` block, so
                    # the transaction context exits as this returns; confirm
                    # callers can still consume the result.
                    return cursor
                else:
                    raise ValueError(
                        "Fetch parameter must be either 'one', 'all', or 'cursor'"
                    )
                return result
        return []

    def run(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        With ``fetch="cursor"`` the raw result object is returned instead.
        String values longer than ``max_string_length`` are truncated. Rows
        are rendered as dictionaries when ``include_columns`` is true and as
        tuples of values otherwise.
        """
        result = self._execute(
            command, fetch, parameters=parameters, execution_options=execution_options
        )

        if fetch == "cursor":
            return result

        res = [
            {
                column: truncate_word(value, length=self._max_string_length)
                for column, value in r.items()
            }
            for r in result
        ]

        if not include_columns:
            res = [tuple(row.values()) for row in res]  # type: ignore[misc]

        if not res:
            return ""
        else:
            return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Unlike ``get_table_info``, a ``ValueError`` is not raised but returned
        as an ``"Error: ..."`` string.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        Only ``SQLAlchemyError`` is converted this way; other exceptions
        propagate.
        """
        try:
            return self.run(
                command,
                fetch,
                parameters=parameters,
                execution_options=execution_options,
                include_columns=include_columns,
            )
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"

    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt.

        Returns:
            A dictionary with ``table_info`` (the table descriptions, or an
            error string) and ``table_names`` (comma-separated usable tables).
        """
        table_names = list(self.get_usable_table_names())
        table_info = self.get_table_info_no_throw()
        return {"table_info": table_info, "table_names": ", ".join(table_names)}
|