mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 16:39:20 +00:00
SQLDatabase drop the column names in the result. (#15361)
Fix for the following bug: https://github.com/langchain-ai/langchain/issues/15360 --------- Co-authored-by: dudu butbul <100126964+dudu-upstream@users.noreply.github.com>
This commit is contained in:
parent
07d294b5ec
commit
7e6b0056b8
@ -427,6 +427,7 @@ class SQLDatabase:
|
||||
self,
|
||||
command: str,
|
||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||
include_columns: bool = False,
|
||||
) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
@ -434,12 +435,18 @@ class SQLDatabase:
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
"""
|
||||
result = self._execute(command, fetch)
|
||||
# Convert columns values to string to avoid issues with sqlalchemy
|
||||
# truncating text
|
||||
|
||||
res = [
|
||||
tuple(truncate_word(c, length=self._max_string_length) for c in r.values())
|
||||
{
|
||||
column: truncate_word(value, length=self._max_string_length)
|
||||
for column, value in r.items()
|
||||
}
|
||||
for r in result
|
||||
]
|
||||
|
||||
if not include_columns:
|
||||
res = [tuple(row.values()) for row in res]
|
||||
|
||||
if not res:
|
||||
return ""
|
||||
else:
|
||||
@ -465,6 +472,7 @@ class SQLDatabase:
|
||||
self,
|
||||
command: str,
|
||||
fetch: Union[Literal["all"], Literal["one"]] = "all",
|
||||
include_columns: bool = False,
|
||||
) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
@ -474,7 +482,7 @@ class SQLDatabase:
|
||||
If the statement throws an error, the error message is returned.
|
||||
"""
|
||||
try:
|
||||
return self.run(command, fetch)
|
||||
return self.run(command, fetch, include_columns)
|
||||
except SQLAlchemyError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
@ -120,10 +120,16 @@ def test_sql_database_run() -> None:
|
||||
conn.execute(stmt)
|
||||
db = SQLDatabase(engine)
|
||||
command = "select user_id, user_name, user_bio from user where user_id = 13"
|
||||
output = db.run(command)
|
||||
partial_output = db.run(command)
|
||||
user_bio = "That is my Bio " * 19 + "That is my..."
|
||||
expected_output = f"[(13, 'Harrison', '{user_bio}')]"
|
||||
assert output == expected_output
|
||||
expected_partial_output = f"[(13, 'Harrison', '{user_bio}')]"
|
||||
assert partial_output == expected_partial_output
|
||||
|
||||
full_output = db.run(command, include_columns=True)
|
||||
expected_full_output = (
|
||||
"[{'user_id': 13, 'user_name': 'Harrison', 'user_bio': '%s'}]" % user_bio
|
||||
)
|
||||
assert full_output == expected_full_output
|
||||
|
||||
|
||||
def test_sql_database_run_update() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user