community[minor]: SQLDatabase Add fetch mode cursor, query parameters, query by selectable, expose execution options, and documentation (#17191)

- **Description:** Improve `SQLDatabase` adapter component to promote
code re-use, see
[suggestion](https://github.com/langchain-ai/langchain/pull/16246#pullrequestreview-1846590962).
  - **Needed by:** GH-16246
  - **Addressed to:** @baskaryan, @cbornet 

## Details
- Add `cursor` fetch mode
- Accept SQL query parameters
- Accept both `str` and SQLAlchemy selectables as query expression
- Expose `execution_options`
- Documentation page (notebook) about `SQLDatabase` [^1]
See [About
SQLDatabase](https://github.com/langchain-ai/langchain/blob/c1c7b763/docs/docs/integrations/tools/sql_database.ipynb).

[^1]: Apparently there hasn't been any yet?

---------

Co-authored-by: Andreas Motl <andreas.motl@crate.io>
This commit is contained in:
Eugene Yurtsev
2024-02-07 22:23:43 -05:00
committed by GitHub
parent 7e4b676d53
commit 780e84ae79
6 changed files with 600 additions and 26 deletions

View File

@@ -1,6 +1,8 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Sequence, Type, Union
from sqlalchemy import Result
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@@ -42,7 +44,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
"""Execute the query, return the results or an error message."""
return self.db.run_no_throw(query)

View File

@@ -1,12 +1,21 @@
"""SQLAlchemy wrapper around a database."""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy import (
Executable,
MetaData,
Result,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
@@ -373,67 +382,113 @@ class SQLDatabase:
def _execute(
self,
command: str,
fetch: Literal["all", "one"] = "all",
) -> Sequence[Dict[str, Any]]:
command: Union[str, Executable],
fetch: Literal["all", "one", "cursor"] = "all",
*,
parameters: Optional[Dict[str, Any]] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> Union[Sequence[Dict[str, Any]], Result]:
"""
Executes SQL command through underlying engine.
If the statement returns no rows, an empty list is returned.
"""
parameters = parameters or {}
execution_options = execution_options or {}
with self._engine.begin() as connection: # type: Connection # type: ignore[name-defined]
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
"ALTER SESSION SET search_path = %s", (self._schema,)
"ALTER SESSION SET search_path = %s",
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "bigquery":
connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
connection.exec_driver_sql(
"SET @@dataset_id=?",
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "mssql":
pass
elif self.dialect == "trino":
connection.exec_driver_sql("USE ?", (self._schema,))
connection.exec_driver_sql(
"USE ?",
(self._schema,),
execution_options=execution_options,
)
elif self.dialect == "duckdb":
# Unclear which parameterized argument syntax duckdb supports.
# The docs for the duckdb client say they support multiple,
# but `duckdb_engine` seemed to struggle with all of them:
# https://github.com/Mause/duckdb_engine/issues/796
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
connection.exec_driver_sql(
f"SET search_path TO {self._schema}",
execution_options=execution_options,
)
elif self.dialect == "oracle":
connection.exec_driver_sql(
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
execution_options=execution_options,
)
elif self.dialect == "sqlany":
# If anybody using Sybase SQL anywhere database then it should not
# go to else condition. It should be same as mssql.
pass
elif self.dialect == "postgresql": # postgresql
connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
connection.exec_driver_sql(
"SET search_path TO %s",
(self._schema,),
execution_options=execution_options,
)
if isinstance(command, str):
command = text(command)
elif isinstance(command, Executable):
pass
else:
raise TypeError(f"Query expression has unknown type: {type(command)}")
cursor = connection.execute(
command,
parameters,
execution_options=execution_options,
)
cursor = connection.execute(text(command))
if cursor.returns_rows:
if fetch == "all":
result = [x._asdict() for x in cursor.fetchall()]
elif fetch == "one":
first_result = cursor.fetchone()
result = [] if first_result is None else [first_result._asdict()]
elif fetch == "cursor":
return cursor
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
raise ValueError(
"Fetch parameter must be either 'one', 'all', or 'cursor'"
)
return result
return []
def run(
self,
command: str,
fetch: Literal["all", "one"] = "all",
command: Union[str, Executable],
fetch: Literal["all", "one", "cursor"] = "all",
include_columns: bool = False,
) -> str:
*,
parameters: Optional[Dict[str, Any]] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
result = self._execute(command, fetch)
result = self._execute(
command, fetch, parameters=parameters, execution_options=execution_options
)
if fetch == "cursor":
return result
res = [
{
@@ -472,7 +527,10 @@ class SQLDatabase:
command: str,
fetch: Literal["all", "one"] = "all",
include_columns: bool = False,
) -> str:
*,
parameters: Optional[Dict[str, Any]] = None,
execution_options: Optional[Dict[str, Any]] = None,
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
@@ -481,7 +539,13 @@ class SQLDatabase:
If the statement throws an error, the error message is returned.
"""
try:
return self.run(command, fetch, include_columns)
return self.run(
command,
fetch,
parameters=parameters,
execution_options=execution_options,
include_columns=include_columns,
)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"

View File

@@ -173,7 +173,7 @@ class SQLDatabaseChain(Chain):
sql_cmd = checked_sql_command
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
_run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
# If return direct, we just set the final result equal to
# the result of the sql query result, otherwise try to get a human readable
# final answer

View File

@@ -78,6 +78,7 @@ class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser):
def get_result_from_sqldb(db: SQLDatabase, cmd: str) -> Sequence[Dict[str, Any]]:
result = db._execute(cmd, fetch="all")
assert isinstance(result, Sequence)
return result

View File

@@ -1,16 +1,19 @@
# flake8: noqa=E501
# flake8: noqa: E501
"""Test SQL database wrapper."""
import pytest
import sqlalchemy as sa
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
from sqlalchemy import (
Column,
Integer,
MetaData,
Result,
String,
Table,
Text,
create_engine,
insert,
select,
)
metadata_obj = MetaData()
@@ -108,8 +111,8 @@ def test_table_info_w_sample_rows() -> None:
assert sorted(output.split()) == sorted(expected_output.split())
def test_sql_database_run() -> None:
"""Test that commands can be run successfully and returned in correct format."""
def test_sql_database_run_fetch_all() -> None:
"""Verify running SQL expressions returning results as strings."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(
@@ -131,6 +134,52 @@ def test_sql_database_run() -> None:
assert full_output == expected_full_output
def test_sql_database_run_fetch_result() -> None:
"""Verify running SQL expressions returning results as SQLAlchemy `Result` instances."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=17, user_name="hwchase")
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = "select user_id, user_name, user_bio from user where user_id = 17"
result = db.run(command, fetch="cursor", include_columns=True)
expected = [{"user_id": 17, "user_name": "hwchase", "user_bio": None}]
assert isinstance(result, Result)
assert result.mappings().fetchall() == expected
def test_sql_database_run_with_parameters() -> None:
"""Verify running SQL expressions with query parameters."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=17, user_name="hwchase")
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = "select user_id, user_name, user_bio from user where user_id = :user_id"
full_output = db.run(command, parameters={"user_id": 17}, include_columns=True)
expected_full_output = "[{'user_id': 17, 'user_name': 'hwchase', 'user_bio': None}]"
assert full_output == expected_full_output
def test_sql_database_run_sqlalchemy_selectable() -> None:
"""Verify running SQL expressions using SQLAlchemy selectable."""
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
stmt = insert(user).values(user_id=17, user_name="hwchase")
with engine.begin() as conn:
conn.execute(stmt)
db = SQLDatabase(engine)
command = select(user).where(user.c.user_id == 17)
full_output = db.run(command, include_columns=True)
expected_full_output = "[{'user_id': 17, 'user_name': 'hwchase', 'user_bio': None}]"
assert full_output == expected_full_output
def test_sql_database_run_update() -> None:
"""Test commands which return no rows return an empty string."""
engine = create_engine("sqlite:///:memory:")
@@ -145,6 +194,24 @@ def test_sql_database_run_update() -> None:
assert output == expected_output
def test_sql_database_schema_translate_map() -> None:
"""Verify using statement-specific execution options."""
engine = create_engine("sqlite:///:memory:")
db = SQLDatabase(engine)
# Define query using SQLAlchemy selectable.
command = select(user).where(user.c.user_id == 17)
# Define statement-specific execution options.
execution_options = {"schema_translate_map": {None: "bar"}}
# Verify the schema translation is applied.
with pytest.raises(sa.exc.OperationalError) as ex:
db.run(command, execution_options=execution_options, fetch="cursor")
assert ex.match("no such table: bar.user")
def test_truncate_word() -> None:
assert truncate_word("Hello World", length=5) == "He..."
assert truncate_word("Hello World", length=0) == "Hello World"