diff --git a/libs/langchain/langchain/tools/sql_database/prompt.py b/libs/langchain/langchain/tools/sql_database/prompt.py index 8d2b097358f..34ab0fd3b16 100644 --- a/libs/langchain/langchain/tools/sql_database/prompt.py +++ b/libs/langchain/langchain/tools/sql_database/prompt.py @@ -11,4 +11,8 @@ Double check the {dialect} query above for common mistakes, including: - Casting to the correct data type - Using the proper columns for joins -If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.""" +If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query. + +Output the final SQL query only. + +SQL Query: """ diff --git a/libs/langchain/langchain/utilities/sql_database.py b/libs/langchain/langchain/utilities/sql_database.py index e0dd8744acd..110f081d3c0 100644 --- a/libs/langchain/langchain/utilities/sql_database.py +++ b/libs/langchain/langchain/utilities/sql_database.py @@ -2,7 +2,7 @@ from __future__ import annotations import warnings -from typing import Any, Iterable, List, Optional +from typing import Any, Iterable, List, Optional, Sequence import sqlalchemy from sqlalchemy import MetaData, Table, create_engine, inspect, select, text @@ -368,12 +368,11 @@ class SQLDatabase: f"{sample_rows_str}" ) - def run(self, command: str, fetch: str = "all") -> str: - """Execute a SQL command and return a string representing the results. - - If the statement returns rows, a string of the results is returned. - If the statement returns no rows, an empty string is returned. + def _execute(self, command: str, fetch: Optional[str] = "all") -> Sequence: + """ + Executes SQL command through underlying engine. + If the statement returns no rows, an empty list is returned. """ with self._engine.begin() as connection: if self._schema is not None: @@ -395,26 +394,30 @@ class SQLDatabase: result = cursor.fetchone() # type: ignore else: raise ValueError("Fetch parameter must be either 'one' or 'all'") + return result + return [] - # Convert columns values to string to avoid issues with sqlalchmey - # trunacating text - if isinstance(result, list): - return str( - [ - tuple( - truncate_word(c, length=self._max_string_length) - for c in r - ) - for r in result - ] - ) + def run(self, command: str, fetch: str = "all") -> str: + """Execute a SQL command and return a string representing the results. - return str( - tuple( - truncate_word(c, length=self._max_string_length) for c in result - ) - ) - return "" + If the statement returns rows, a string of the results is returned. + If the statement returns no rows, an empty string is returned. + """ + result = self._execute(command, fetch) + # Convert columns values to string to avoid issues with sqlalchemy + # truncating text + if not result: + return "" + elif isinstance(result, list): + res: Sequence = [ + tuple(truncate_word(c, length=self._max_string_length) for c in r) + for r in result + ] + else: + res = tuple( + truncate_word(c, length=self._max_string_length) for c in result + ) + return str(res) def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: """Get information about specified tables.