diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index f9589a601e1..978c4d0a596 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -33,27 +33,39 @@ class SQLDatabaseToolkit(BaseToolkit): def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" + list_sql_database_tool = ListSQLDatabaseTool(db=self.db) + info_sql_database_tool_description = ( + "Input to this tool is a comma-separated list of tables, output is the " + "schema and sample rows for those tables. " + "Be sure that the tables actually exist by calling " + f"{list_sql_database_tool.name} first! " + "Example Input: 'table1, table2, table3'" + ) + info_sql_database_tool = InfoSQLDatabaseTool( + db=self.db, description=info_sql_database_tool_description + ) query_sql_database_tool_description = ( "Input to this tool is a detailed and correct SQL query, output is a " "result from the database. If the query is not correct, an error message " "will be returned. If an error is returned, rewrite the query, check the " "query, and try again. If you encounter an issue with Unknown column " - "'xxxx' in 'field list', using schema_sql_db to query the correct table " - "fields." + f"'xxxx' in 'field list', using {info_sql_database_tool.name} " + "to query the correct table fields." ) - info_sql_database_tool_description = ( - "Input to this tool is a comma-separated list of tables, output is the " - "schema and sample rows for those tables. " - "Be sure that the tables actually exist by calling list_tables_sql_db " - "first! Example Input: 'table1, table2, table3'" + query_sql_database_tool = QuerySQLDataBaseTool( + db=self.db, description=query_sql_database_tool_description + ) + query_sql_checker_tool_description = ( + "Use this tool to double check if your query is correct before executing " + "it. Always use this tool before executing a query with " + f"{query_sql_database_tool.name}!" + ) + query_sql_checker_tool = QuerySQLCheckerTool( + db=self.db, llm=self.llm, description=query_sql_checker_tool_description ) return [ - QuerySQLDataBaseTool( - db=self.db, description=query_sql_database_tool_description - ), - InfoSQLDatabaseTool( - db=self.db, description=info_sql_database_tool_description - ), - ListSQLDatabaseTool(db=self.db), - QuerySQLCheckerTool(db=self.db, llm=self.llm), + query_sql_database_tool, + info_sql_database_tool, + list_sql_database_tool, + query_sql_checker_tool, ]