mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-12 15:59:56 +00:00
Support Databricks in SQLDatabase (#4702)
This PR adds support for Databricks runtime and Databricks SQL by using [Databricks SQL Connector for Python](https://docs.databricks.com/dev-tools/python-sql-connector.html). As a cloud data platform, accessing Databricks requires a URL as follows `databricks://token:{api_token}@{hostname}?http_path={http_path}&catalog={catalog}&schema={schema}`. **The URL is **complicated** and it may take users a while to figure it out**. Since the fields `api_token`/`hostname`/`http_path` fields are known in the Databricks notebook, I am proposing a new method `from_databricks` to simplify the connection to Databricks. ## In Databricks Notebook After changes, Databricks users only need to specify the `catalog` and `schema` field when using langchain. <img width="881" alt="image" src="https://github.com/hwchase17/langchain/assets/1097932/984b4c57-4c2d-489d-b060-5f4918ef2f37"> ## In Jupyter Notebook The method can be used on the local setup as well: <img width="678" alt="image" src="https://github.com/hwchase17/langchain/assets/1097932/142e8805-a6ef-4919-b28e-9796ca31ef19">
This commit is contained in:
parent
88a3a56c1a
commit
bf5a3c6dec
@ -34,7 +34,7 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. The `SQLDatabaseChain` can therefore be used with any SQL dialect supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, and SQLite. Please refer to the SQLAlchemy documentation for more information about requirements for connecting to your database. For example, a connection to MySQL requires an appropriate connector such as PyMySQL. A URI for a MySQL connection might look like: `mysql+pymysql://user:pass@some_mysql_db_address/db_name`\n",
|
"Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. The `SQLDatabaseChain` can therefore be used with any SQL dialect supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, Databricks and SQLite. Please refer to the SQLAlchemy documentation for more information about requirements for connecting to your database. For example, a connection to MySQL requires an appropriate connector such as PyMySQL. A URI for a MySQL connection might look like: `mysql+pymysql://user:pass@some_mysql_db_address/db_name`. To connect to Databricks, it is recommended to use the handy method `SQLDatabase.from_databricks(catalog, schema, host, api_token, (warehouse_id|cluster_id))`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"This demonstration uses SQLite and the example Chinook database.\n",
|
"This demonstration uses SQLite and the example Chinook database.\n",
|
||||||
"To set it up, follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
|
"To set it up, follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
|
||||||
|
@ -17,6 +17,8 @@ from sqlalchemy.engine import Engine
|
|||||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||||
from sqlalchemy.schema import CreateTable
|
from sqlalchemy.schema import CreateTable
|
||||||
|
|
||||||
|
from langchain import utils
|
||||||
|
|
||||||
|
|
||||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||||
return (
|
return (
|
||||||
@ -110,6 +112,105 @@ class SQLDatabase:
|
|||||||
_engine_args = engine_args or {}
|
_engine_args = engine_args or {}
|
||||||
return cls(create_engine(database_uri, **_engine_args), **kwargs)
|
return cls(create_engine(database_uri, **_engine_args), **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_databricks(
|
||||||
|
cls,
|
||||||
|
catalog: str,
|
||||||
|
schema: str,
|
||||||
|
host: Optional[str] = None,
|
||||||
|
api_token: Optional[str] = None,
|
||||||
|
warehouse_id: Optional[str] = None,
|
||||||
|
cluster_id: Optional[str] = None,
|
||||||
|
engine_args: Optional[dict] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> SQLDatabase:
|
||||||
|
"""
|
||||||
|
Class method to create an SQLDatabase instance from a Databricks connection.
|
||||||
|
This method requires the 'databricks-sql-connector' package. If not installed,
|
||||||
|
it can be added using `pip install databricks-sql-connector`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
catalog (str): The catalog name in the Databricks database.
|
||||||
|
schema (str): The schema name in the catalog.
|
||||||
|
host (Optional[str]): The Databricks workspace hostname, excluding
|
||||||
|
'https://' part. If not provided, it attempts to fetch from the
|
||||||
|
environment variable 'DATABRICKS_HOST'. If still unavailable and if
|
||||||
|
running in a Databricks notebook, it defaults to the current workspace
|
||||||
|
hostname. Defaults to None.
|
||||||
|
api_token (Optional[str]): The Databricks personal access token for
|
||||||
|
accessing the Databricks SQL warehouse or the cluster. If not provided,
|
||||||
|
it attempts to fetch from 'DATABRICKS_API_TOKEN'. If still unavailable
|
||||||
|
and running in a Databricks notebook, a temporary token for the current
|
||||||
|
user is generated. Defaults to None.
|
||||||
|
warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
|
||||||
|
provided, the method configures the connection to use this warehouse.
|
||||||
|
Cannot be used with 'cluster_id'. Defaults to None.
|
||||||
|
cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
|
||||||
|
provided, the method configures the connection to use this cluster.
|
||||||
|
Cannot be used with 'warehouse_id'. If running in a Databricks notebook
|
||||||
|
and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
|
||||||
|
cluster the notebook is attached to. Defaults to None.
|
||||||
|
engine_args (Optional[dict]): The arguments to be used when connecting
|
||||||
|
Databricks. Defaults to None.
|
||||||
|
**kwargs (Any): Additional keyword arguments for the `from_uri` method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQLDatabase: An instance of SQLDatabase configured with the provided
|
||||||
|
Databricks connection details.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If 'databricks-sql-connector' is not found, or if both
|
||||||
|
'warehouse_id' and 'cluster_id' are provided, or if neither
|
||||||
|
'warehouse_id' nor 'cluster_id' are provided and it's not executing
|
||||||
|
inside a Databricks notebook.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from databricks import sql # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"databricks-sql-connector package not found, please install with"
|
||||||
|
" `pip install databricks-sql-connector`"
|
||||||
|
)
|
||||||
|
context = None
|
||||||
|
try:
|
||||||
|
from dbruntime.databricks_repl_context import get_context
|
||||||
|
|
||||||
|
context = get_context()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
default_host = context.browserHostName if context else None
|
||||||
|
if host is None:
|
||||||
|
host = utils.get_from_env("host", "DATABRICKS_HOST", default_host)
|
||||||
|
|
||||||
|
default_api_token = context.apiToken if context else None
|
||||||
|
if api_token is None:
|
||||||
|
api_token = utils.get_from_env(
|
||||||
|
"api_token", "DATABRICKS_API_TOKEN", default_api_token
|
||||||
|
)
|
||||||
|
|
||||||
|
if warehouse_id is None and cluster_id is None:
|
||||||
|
if context:
|
||||||
|
cluster_id = context.clusterId
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Need to provide either 'warehouse_id' or 'cluster_id'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if warehouse_id and cluster_id:
|
||||||
|
raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")
|
||||||
|
|
||||||
|
if warehouse_id:
|
||||||
|
http_path = f"/sql/1.0/warehouses/{warehouse_id}"
|
||||||
|
else:
|
||||||
|
http_path = f"/sql/protocolv1/o/0/{cluster_id}"
|
||||||
|
|
||||||
|
uri = (
|
||||||
|
f"databricks://token:{api_token}@{host}?"
|
||||||
|
f"http_path={http_path}&catalog={catalog}&schema={schema}"
|
||||||
|
)
|
||||||
|
return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dialect(self) -> str:
|
def dialect(self) -> str:
|
||||||
"""Return string representation of dialect to use."""
|
"""Return string representation of dialect to use."""
|
||||||
|
Loading…
Reference in New Issue
Block a user