mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
community[minor]: SQLDatabase: Add fetch mode `cursor`, query parameters, query by selectable, expose execution options, and documentation (#17191)
- **Description:** Improve the `SQLDatabase` adapter component to promote code re-use; see this [suggestion](https://github.com/langchain-ai/langchain/pull/16246#pullrequestreview-1846590962).
- **Needed by:** GH-16246
- **Addressed to:** @baskaryan, @cbornet

## Details

- Add `cursor` fetch mode
- Accept SQL query parameters
- Accept both `str` and SQLAlchemy selectables as query expression
- Expose `execution_options`
- Documentation page (notebook) about `SQLDatabase` [^1]

See [About SQLDatabase](https://github.com/langchain-ai/langchain/blob/c1c7b763/docs/docs/integrations/tools/sql_database.ipynb).

[^1]: Apparently there hasn't been any yet?

---------

Co-authored-by: Andreas Motl <andreas.motl@crate.io>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||
|
||||
from sqlalchemy import Result
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
|
||||
@@ -42,7 +44,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
return self.db.run_no_throw(query)
|
||||
|
||||
|
@@ -1,12 +1,21 @@
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.utils import get_from_env
|
||||
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||
from sqlalchemy import (
|
||||
Executable,
|
||||
MetaData,
|
||||
Result,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
@@ -373,67 +382,113 @@ class SQLDatabase:
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Literal["all", "one"] = "all",
|
||||
) -> Sequence[Dict[str, Any]]:
|
||||
command: Union[str, Executable],
|
||||
fetch: Literal["all", "one", "cursor"] = "all",
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[Sequence[Dict[str, Any]], Result]:
|
||||
"""
|
||||
Executes SQL command through underlying engine.
|
||||
|
||||
If the statement returns no rows, an empty list is returned.
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
execution_options = execution_options or {}
|
||||
with self._engine.begin() as connection: # type: Connection # type: ignore[name-defined]
|
||||
if self._schema is not None:
|
||||
if self.dialect == "snowflake":
|
||||
connection.exec_driver_sql(
|
||||
"ALTER SESSION SET search_path = %s", (self._schema,)
|
||||
"ALTER SESSION SET search_path = %s",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "bigquery":
|
||||
connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
|
||||
connection.exec_driver_sql(
|
||||
"SET @@dataset_id=?",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "mssql":
|
||||
pass
|
||||
elif self.dialect == "trino":
|
||||
connection.exec_driver_sql("USE ?", (self._schema,))
|
||||
connection.exec_driver_sql(
|
||||
"USE ?",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "duckdb":
|
||||
# Unclear which parameterized argument syntax duckdb supports.
|
||||
# The docs for the duckdb client say they support multiple,
|
||||
# but `duckdb_engine` seemed to struggle with all of them:
|
||||
# https://github.com/Mause/duckdb_engine/issues/796
|
||||
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
|
||||
connection.exec_driver_sql(
|
||||
f"SET search_path TO {self._schema}",
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "oracle":
|
||||
connection.exec_driver_sql(
|
||||
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
|
||||
f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
|
||||
execution_options=execution_options,
|
||||
)
|
||||
elif self.dialect == "sqlany":
|
||||
# If anybody using Sybase SQL anywhere database then it should not
|
||||
# go to else condition. It should be same as mssql.
|
||||
pass
|
||||
elif self.dialect == "postgresql": # postgresql
|
||||
connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
|
||||
connection.exec_driver_sql(
|
||||
"SET search_path TO %s",
|
||||
(self._schema,),
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
||||
if isinstance(command, str):
|
||||
command = text(command)
|
||||
elif isinstance(command, Executable):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(f"Query expression has unknown type: {type(command)}")
|
||||
cursor = connection.execute(
|
||||
command,
|
||||
parameters,
|
||||
execution_options=execution_options,
|
||||
)
|
||||
|
||||
cursor = connection.execute(text(command))
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = [x._asdict() for x in cursor.fetchall()]
|
||||
elif fetch == "one":
|
||||
first_result = cursor.fetchone()
|
||||
result = [] if first_result is None else [first_result._asdict()]
|
||||
elif fetch == "cursor":
|
||||
return cursor
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
raise ValueError(
|
||||
"Fetch parameter must be either 'one', 'all', or 'cursor'"
|
||||
)
|
||||
return result
|
||||
return []
|
||||
|
||||
def run(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Literal["all", "one"] = "all",
|
||||
command: Union[str, Executable],
|
||||
fetch: Literal["all", "one", "cursor"] = "all",
|
||||
include_columns: bool = False,
|
||||
) -> str:
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
"""
|
||||
result = self._execute(command, fetch)
|
||||
result = self._execute(
|
||||
command, fetch, parameters=parameters, execution_options=execution_options
|
||||
)
|
||||
|
||||
if fetch == "cursor":
|
||||
return result
|
||||
|
||||
res = [
|
||||
{
|
||||
@@ -472,7 +527,10 @@ class SQLDatabase:
|
||||
command: str,
|
||||
fetch: Literal["all", "one"] = "all",
|
||||
include_columns: bool = False,
|
||||
) -> str:
|
||||
*,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
execution_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
@@ -481,7 +539,13 @@ class SQLDatabase:
|
||||
If the statement throws an error, the error message is returned.
|
||||
"""
|
||||
try:
|
||||
return self.run(command, fetch, include_columns)
|
||||
return self.run(
|
||||
command,
|
||||
fetch,
|
||||
parameters=parameters,
|
||||
execution_options=execution_options,
|
||||
include_columns=include_columns,
|
||||
)
|
||||
except SQLAlchemyError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
@@ -173,7 +173,7 @@ class SQLDatabaseChain(Chain):
|
||||
sql_cmd = checked_sql_command
|
||||
|
||||
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
||||
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
|
||||
_run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
|
||||
# If return direct, we just set the final result equal to
|
||||
# the result of the sql query result, otherwise try to get a human readable
|
||||
# final answer
|
||||
|
@@ -78,6 +78,7 @@ class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser):
|
||||
|
||||
def get_result_from_sqldb(db: SQLDatabase, cmd: str) -> Sequence[Dict[str, Any]]:
    """Run ``cmd`` through ``db._execute`` in ``"all"`` fetch mode and return the rows.

    Args:
        db: Database wrapper whose private ``_execute`` method is invoked.
        cmd: SQL command to execute.

    Returns:
        The value produced by ``db._execute(cmd, fetch="all")``.

    Raises:
        AssertionError: If ``_execute`` hands back something that is not a
            ``Sequence``.
    """
    rows = db._execute(cmd, fetch="all")
    # Guard the declared return type: only a sequence may be passed on.
    assert isinstance(rows, Sequence)
    return rows
|
||||
|
||||
|
||||
|
@@ -1,16 +1,19 @@
|
||||
# flake8: noqa=E501
|
||||
# flake8: noqa: E501
|
||||
"""Test SQL database wrapper."""
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
from langchain_community.utilities.sql_database import SQLDatabase, truncate_word
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
Result,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
create_engine,
|
||||
insert,
|
||||
select,
|
||||
)
|
||||
|
||||
metadata_obj = MetaData()
|
||||
@@ -108,8 +111,8 @@ def test_table_info_w_sample_rows() -> None:
|
||||
assert sorted(output.split()) == sorted(expected_output.split())
|
||||
|
||||
|
||||
def test_sql_database_run() -> None:
|
||||
"""Test that commands can be run successfully and returned in correct format."""
|
||||
def test_sql_database_run_fetch_all() -> None:
|
||||
"""Verify running SQL expressions returning results as strings."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj.create_all(engine)
|
||||
stmt = insert(user).values(
|
||||
@@ -131,6 +134,52 @@ def test_sql_database_run() -> None:
|
||||
assert full_output == expected_full_output
|
||||
|
||||
|
||||
def test_sql_database_run_fetch_result() -> None:
    """Check that ``fetch="cursor"`` makes ``run`` return a SQLAlchemy `Result`."""
    sqlite_engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(sqlite_engine)

    # Seed a single row to query back.
    with sqlite_engine.begin() as connection:
        connection.execute(insert(user).values(user_id=17, user_name="hwchase"))

    database = SQLDatabase(sqlite_engine)
    query = "select user_id, user_name, user_bio from user where user_id = 17"
    outcome = database.run(query, fetch="cursor", include_columns=True)

    assert isinstance(outcome, Result)
    # Rows are read from the returned result object and compared as mappings.
    assert outcome.mappings().fetchall() == [
        {"user_id": 17, "user_name": "hwchase", "user_bio": None}
    ]
|
||||
|
||||
|
||||
def test_sql_database_run_with_parameters() -> None:
    """Check that ``run`` accepts a ``parameters`` dict for a ``:name`` placeholder."""
    sqlite_engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(sqlite_engine)

    # Seed a single row to query back.
    with sqlite_engine.begin() as connection:
        connection.execute(insert(user).values(user_id=17, user_name="hwchase"))

    database = SQLDatabase(sqlite_engine)
    query = "select user_id, user_name, user_bio from user where user_id = :user_id"
    rendered = database.run(query, parameters={"user_id": 17}, include_columns=True)

    assert rendered == "[{'user_id': 17, 'user_name': 'hwchase', 'user_bio': None}]"
|
||||
|
||||
|
||||
def test_sql_database_run_sqlalchemy_selectable() -> None:
    """Check that ``run`` accepts a SQLAlchemy selectable instead of a SQL string."""
    sqlite_engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(sqlite_engine)

    # Seed a single row to query back.
    with sqlite_engine.begin() as connection:
        connection.execute(insert(user).values(user_id=17, user_name="hwchase"))

    database = SQLDatabase(sqlite_engine)
    query = select(user).where(user.c.user_id == 17)
    rendered = database.run(query, include_columns=True)

    assert rendered == "[{'user_id': 17, 'user_name': 'hwchase', 'user_bio': None}]"
|
||||
|
||||
|
||||
def test_sql_database_run_update() -> None:
|
||||
"""Test commands which return no rows return an empty string."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
@@ -145,6 +194,24 @@ def test_sql_database_run_update() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_sql_database_schema_translate_map() -> None:
    """Check that statement-specific execution options are honoured by ``run``."""
    database = SQLDatabase(create_engine("sqlite:///:memory:"))

    # Build the query as a SQLAlchemy selectable rather than a raw string.
    query = select(user).where(user.c.user_id == 17)

    # Execution options that apply to this one statement only.
    options = {"schema_translate_map": {None: "bar"}}

    # The failure message must name the translated schema, proving it was applied.
    with pytest.raises(sa.exc.OperationalError) as excinfo:
        database.run(query, execution_options=options, fetch="cursor")
    assert excinfo.match("no such table: bar.user")
|
||||
|
||||
|
||||
def test_truncate_word() -> None:
|
||||
assert truncate_word("Hello World", length=5) == "He..."
|
||||
assert truncate_word("Hello World", length=0) == "Hello World"
|
||||
|
Reference in New Issue
Block a user