mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 10:39:23 +00:00
Improve type hints and interface for SQL execution functionality. (#11353)
The previous API of the `_execute()` function had a few rough edges that this PR addresses: - The `fetch` argument was type-hinted as being able to take any string, but any string other than `"all"` or `"one"` would `raise ValueError`. The new type hints explicitly declare that only those values are supported. - The return type was type-hinted as `Sequence` but using `fetch = "one"` would actually return a single result item. This was incorrectly suppressed using `# type: ignore`. We now always return a list. - Using `fetch = "one"` would return a single item if data was found, or an empty *list* if no data was found. This was confusing, and we now always return a list to simplify. - The return type was `Sequence[Any]` which was a bit difficult to use since it wasn't clear what one could do with the returned rows. I'm making the new type `Dict[str, Any]` that corresponds to the column names and their values in the query. I've updated the use of this method elsewhere in the file to match the new behavior.
This commit is contained in:
parent
3bddd708f7
commit
42d979efdd
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Iterable, List, Optional, Sequence
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||
@ -374,7 +374,11 @@ class SQLDatabase:
|
||||
f"{sample_rows_str}"
|
||||
)
|
||||
|
||||
def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence:
|
||||
def _execute(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||
) -> Sequence[Dict[str, Any]]:
|
||||
"""
|
||||
Executes SQL command through underlying engine.
|
||||
|
||||
@ -397,15 +401,20 @@ class SQLDatabase:
|
||||
cursor = connection.execute(text(command))
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
result = [x._asdict() for x in cursor.fetchall()]
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone() # type: ignore
|
||||
first_result = cursor.fetchone()
|
||||
result = [] if first_result is None else [first_result._asdict()]
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
return result
|
||||
return []
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
def run(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||
) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
@ -414,18 +423,14 @@ class SQLDatabase:
|
||||
result = self._execute(command, fetch)
|
||||
# Convert columns values to string to avoid issues with sqlalchemy
|
||||
# truncating text
|
||||
if not result:
|
||||
res = [
|
||||
tuple(truncate_word(c, length=self._max_string_length) for c in r.values())
|
||||
for r in result
|
||||
]
|
||||
if not res:
|
||||
return ""
|
||||
elif isinstance(result, list):
|
||||
res: Sequence = [
|
||||
tuple(truncate_word(c, length=self._max_string_length) for c in r)
|
||||
for r in result
|
||||
]
|
||||
else:
|
||||
res = tuple(
|
||||
truncate_word(c, length=self._max_string_length) for c in result
|
||||
)
|
||||
return str(res)
|
||||
return str(res)
|
||||
|
||||
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables.
|
||||
@ -443,7 +448,11 @@ class SQLDatabase:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def run_no_throw(self, command: str, fetch: str = "all") -> str:
|
||||
def run_no_throw(
|
||||
self,
|
||||
command: str,
|
||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||
) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
|
Loading…
Reference in New Issue
Block a user