community[minor]: SQLDatabase Add fetch mode cursor, query parameters, query by selectable, expose execution options, and documentation (#17191)

- **Description:** Improve `SQLDatabase` adapter component to promote
code re-use, see
[suggestion](https://github.com/langchain-ai/langchain/pull/16246#pullrequestreview-1846590962).
  - **Needed by:** GH-16246
  - **Addressed to:** @baskaryan, @cbornet 

## Details
- Add `cursor` fetch mode
- Accept SQL query parameters
- Accept both `str` and SQLAlchemy selectables as query expression
- Expose `execution_options`
- Documentation page (notebook) about `SQLDatabase` [^1]
See [About
SQLDatabase](https://github.com/langchain-ai/langchain/blob/c1c7b763/docs/docs/integrations/tools/sql_database.ipynb).

[^1]: Apparently there hasn't been any yet?

---------

Co-authored-by: Andreas Motl <andreas.motl@crate.io>
This commit is contained in:
Eugene Yurtsev
2024-02-07 22:23:43 -05:00
committed by GitHub
parent 7e4b676d53
commit 780e84ae79
6 changed files with 600 additions and 26 deletions

View File

@@ -1,12 +1,21 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy import (
Executable,
MetaData,
Result,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
@@ -373,67 +382,113 @@ class SQLDatabase:
def _execute(
self,
command: str,
fetch: Literal["all", "one"] = "all",
) -> Sequence[Dict[str, Any]]:
command: Union[str, Executable],
fetch: Literal["all", "one", "cursor"] = "all",
*,
parameters: Optional[Dict[str, Any]] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> Union[Sequence[Dict[str, Any]], Result]:
"""
Executes SQL command through underlying engine.
If the statement returns no rows, an empty list is returned.
"""
parameters = parameters or {}
execution_options = execution_options or {}
with self._engine.begin() as connection: # type: Connection # type: ignore[name-defined]
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
"ALTER SESSION SET search_path = %s", (self._schema,)
"ALTER SESSION SET search_path = %s",
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "bigquery":
connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
connection.exec_driver_sql(
"SET @@dataset_id=?",
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "mssql":
pass
elif self.dialect == "trino":
connection.exec_driver_sql("USE ?", (self._schema,))
connection.exec_driver_sql(
"USE ?",
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "duckdb":
# Unclear which parameterized argument syntax duckdb supports.
# The docs for the duckdb client say they support multiple,
# but `duckdb_engine` seemed to struggle with all of them:
# https://github.com/Mause/duckdb_engine/issues/796
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
connection.exec_driver_sql(
f"SET search_path TO {self._schema}",
execution_options=execution_options,
)
elif self.dialect == "oracle":
connection.exec_driver_sql(
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
execution_options=execution_options,
)
elif self.dialect == "sqlany":
# If anybody using Sybase SQL anywhere database then it should not
# go to else condition. It should be same as mssql.
pass
elif self.dialect == "postgresql": # postgresql
connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
connection.exec_driver_sql(
"SET search_path TO %s",
(self._schema,),
execution_options=execution_options,
)
if isinstance(command, str):
command = text(command)
elif isinstance(command, Executable):
pass
else:
raise TypeError(f"Query expression has unknown type: {type(command)}")
cursor = connection.execute(
command,
parameters,
execution_options=execution_options,
)
cursor = connection.execute(text(command))
if cursor.returns_rows:
if fetch == "all":
result = [x._asdict() for x in cursor.fetchall()]
elif fetch == "one":
first_result = cursor.fetchone()
result = [] if first_result is None else [first_result._asdict()]
elif fetch == "cursor":
return cursor
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
raise ValueError(
"Fetch parameter must be either 'one', 'all', or 'cursor'"
)
return result
return []
def run(
self,
command: str,
fetch: Literal["all", "one"] = "all",
command: Union[str, Executable],
fetch: Literal["all", "one", "cursor"] = "all",
include_columns: bool = False,
) -> str:
*,
parameters: Optional[Dict[str, Any]] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
result = self._execute(command, fetch)
result = self._execute(
command, fetch, parameters=parameters, execution_options=execution_options
)
if fetch == "cursor":
return result
res = [
{
@@ -472,7 +527,10 @@ class SQLDatabase:
command: str,
fetch: Literal["all", "one"] = "all",
include_columns: bool = False,
) -> str:
*,
parameters: Optional[Dict[str, Any]] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
@@ -481,7 +539,13 @@ class SQLDatabase:
If the statement throws an error, the error message is returned.
"""
try:
return self.run(command, fetch, include_columns)
return self.run(
command,
fetch,
parameters=parameters,
execution_options=execution_options,
include_columns=include_columns,
)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"