mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
SQLDatabase drop the column names in the result. (#15361)
Fix for the following bug: https://github.com/langchain-ai/langchain/issues/15360 --------- Co-authored-by: dudu butbul <100126964+dudu-upstream@users.noreply.github.com>
This commit is contained in:
parent
07d294b5ec
commit
7e6b0056b8
@ -427,6 +427,7 @@ class SQLDatabase:
|
|||||||
self,
|
self,
|
||||||
command: str,
|
command: str,
|
||||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||||
|
include_columns: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute a SQL command and return a string representing the results.
|
"""Execute a SQL command and return a string representing the results.
|
||||||
|
|
||||||
@ -434,12 +435,18 @@ class SQLDatabase:
|
|||||||
If the statement returns no rows, an empty string is returned.
|
If the statement returns no rows, an empty string is returned.
|
||||||
"""
|
"""
|
||||||
result = self._execute(command, fetch)
|
result = self._execute(command, fetch)
|
||||||
# Convert columns values to string to avoid issues with sqlalchemy
|
|
||||||
# truncating text
|
|
||||||
res = [
|
res = [
|
||||||
tuple(truncate_word(c, length=self._max_string_length) for c in r.values())
|
{
|
||||||
|
column: truncate_word(value, length=self._max_string_length)
|
||||||
|
for column, value in r.items()
|
||||||
|
}
|
||||||
for r in result
|
for r in result
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if not include_columns:
|
||||||
|
res = [tuple(row.values()) for row in res]
|
||||||
|
|
||||||
if not res:
|
if not res:
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
@ -465,6 +472,7 @@ class SQLDatabase:
|
|||||||
self,
|
self,
|
||||||
command: str,
|
command: str,
|
||||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||||
|
include_columns: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute a SQL command and return a string representing the results.
|
"""Execute a SQL command and return a string representing the results.
|
||||||
|
|
||||||
@ -474,7 +482,7 @@ class SQLDatabase:
|
|||||||
If the statement throws an error, the error message is returned.
|
If the statement throws an error, the error message is returned.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self.run(command, fetch)
|
return self.run(command, fetch, include_columns)
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
"""Format the error message"""
|
"""Format the error message"""
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
@ -120,10 +120,16 @@ def test_sql_database_run() -> None:
|
|||||||
conn.execute(stmt)
|
conn.execute(stmt)
|
||||||
db = SQLDatabase(engine)
|
db = SQLDatabase(engine)
|
||||||
command = "select user_id, user_name, user_bio from user where user_id = 13"
|
command = "select user_id, user_name, user_bio from user where user_id = 13"
|
||||||
output = db.run(command)
|
partial_output = db.run(command)
|
||||||
user_bio = "That is my Bio " * 19 + "That is my..."
|
user_bio = "That is my Bio " * 19 + "That is my..."
|
||||||
expected_output = f"[(13, 'Harrison', '{user_bio}')]"
|
expected_partial_output = f"[(13, 'Harrison', '{user_bio}')]"
|
||||||
assert output == expected_output
|
assert partial_output == expected_partial_output
|
||||||
|
|
||||||
|
full_output = db.run(command, include_columns=True)
|
||||||
|
expected_full_output = (
|
||||||
|
"[{'user_id': 13, 'user_name': 'Harrison', 'user_bio': '%s'}]" % user_bio
|
||||||
|
)
|
||||||
|
assert full_output == expected_full_output
|
||||||
|
|
||||||
|
|
||||||
def test_sql_database_run_update() -> None:
|
def test_sql_database_run_update() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user