mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
community[minor]: SQLDatabase Add fetch mode cursor, query parameters, query by selectable, expose execution options, and documentation (#17191)
- **Description:** Improve `SQLDatabase` adapter component to promote code re-use, see [suggestion](https://github.com/langchain-ai/langchain/pull/16246#pullrequestreview-1846590962).
- **Needed by:** GH-16246
- **Addressed to:** @baskaryan, @cbornet

## Details

- Add `cursor` fetch mode
- Accept SQL query parameters
- Accept both `str` and SQLAlchemy selectables as query expression
- Expose `execution_options`
- Documentation page (notebook) about `SQLDatabase` [^1]
  See [About SQLDatabase](https://github.com/langchain-ai/langchain/blob/c1c7b763/docs/docs/integrations/tools/sql_database.ipynb).

[^1]: Apparently there hasn't been any yet?

---------

Co-authored-by: Andreas Motl <andreas.motl@crate.io>
This commit is contained in:
@@ -1,12 +1,21 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.utils import get_from_env
|
||||
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||
from sqlalchemy import (
|
||||
Executable,
|
||||
MetaData,
|
||||
Result,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
@@ -373,67 +382,113 @@ class SQLDatabase:
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Literal["all", "one"] = "all",
|
||||
) -> Sequence[Dict[str, Any]]:
|
||||
command: Union[str, Executable],
|
||||
fetch: Literal["all", "one", "cursor"] = "all",
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[Sequence[Dict[str, Any]], Result]:
|
||||
"""
|
||||
Executes SQL command through underlying engine.
|
||||
|
||||
If the statement returns no rows, an empty list is returned.
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
execution_options = execution_options or {}
|
||||
with self._engine.begin() as connection: # type: Connection # type: ignore[name-defined]
|
||||
if self._schema is not None:
|
||||
if self.dialect == "snowflake":
|
||||
connection.exec_driver_sql(
|
||||
"ALTER SESSION SET search_path = %s", (self._schema,)
|
||||
"ALTER SESSION SET search_path = %s",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "bigquery":
|
||||
connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
|
||||
connection.exec_driver_sql(
|
||||
"SET @@dataset_id=?",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "mssql":
|
||||
pass
|
||||
elif self.dialect == "trino":
|
||||
connection.exec_driver_sql("USE ?", (self._schema,))
|
||||
connection.exec_driver_sql(
|
||||
"USE ?",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "duckdb":
|
||||
# Unclear which parameterized argument syntax duckdb supports.
|
||||
# The docs for the duckdb client say they support multiple,
|
||||
# but `duckdb_engine` seemed to struggle with all of them:
|
||||
# https://github.com/Mause/duckdb_engine/issues/796
|
||||
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
|
||||
connection.exec_driver_sql(
|
||||
f"SET search_path TO {self._schema}",
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "oracle":
|
||||
connection.exec_driver_sql(
|
||||
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
|
||||
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "sqlany":
|
||||
# If anybody using Sybase SQL anywhere database then it should not
|
||||
# go to else condition. It should be same as mssql.
|
||||
pass
|
||||
elif self.dialect == "postgresql": # postgresql
|
||||
connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
|
||||
connection.exec_driver_sql(
|
||||
"SET search_path TO %s",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
||||
if isinstance(command, str):
|
||||
command = text(command)
|
||||
elif isinstance(command, Executable):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(f"Query expression has unknown type: {type(command)}")
|
||||
cursor = connection.execute(
|
||||
command,
|
||||
parameters,
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
||||
cursor = connection.execute(text(command))
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = [x._asdict() for x in cursor.fetchall()]
|
||||
elif fetch == "one":
|
||||
first_result = cursor.fetchone()
|
||||
result = [] if first_result is None else [first_result._asdict()]
|
||||
elif fetch == "cursor":
|
||||
return cursor
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
raise ValueError(
|
||||
"Fetch parameter must be either 'one', 'all', or 'cursor'"
|
||||
)
|
||||
return result
|
||||
return []
|
||||
|
||||
def run(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Literal["all", "one"] = "all",
|
||||
command: Union[str, Executable],
|
||||
fetch: Literal["all", "one", "cursor"] = "all",
|
||||
include_columns: bool = False,
|
||||
) -> str:
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
"""
|
||||
result = self._execute(command, fetch)
|
||||
result = self._execute(
|
||||
command, fetch, parameters=parameters, execution_options=execution_options
|
||||
)
|
||||
|
||||
if fetch == "cursor":
|
||||
return result
|
||||
|
||||
res = [
|
||||
{
|
||||
@@ -472,7 +527,10 @@ class SQLDatabase:
|
||||
command: str,
|
||||
fetch: Literal["all", "one"] = "all",
|
||||
include_columns: bool = False,
|
||||
) -> str:
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
@@ -481,7 +539,13 @@ class SQLDatabase:
|
||||
If the statement throws an error, the error message is returned.
|
||||
"""
|
||||
try:
|
||||
return self.run(command, fetch, include_columns)
|
||||
return self.run(
|
||||
command,
|
||||
fetch,
|
||||
parameters=parameters,
|
||||
execution_options=execution_options,
|
||||
include_columns=include_columns,
|
||||
)
|
||||
except SQLAlchemyError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
Reference in New Issue
Block a user