diff --git a/libs/community/langchain_community/tools/sql_database/tool.py b/libs/community/langchain_community/tools/sql_database/tool.py index 9d4f90b7245..62e3349365d 100644 --- a/libs/community/langchain_community/tools/sql_database/tool.py +++ b/libs/community/langchain_community/tools/sql_database/tool.py @@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. """ + args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput def _run( self, @@ -77,11 +78,16 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): ) +class _ListSQLDataBaseToolInput(BaseModel): + tool_input: str = Field(..., description="An empty string") + + class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): """Tool for getting tables names.""" name: str = "sql_db_list_tables" description: str = "Input is an empty string, output is a comma separated list of tables in the database." + args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput def _run( self, @@ -92,6 +98,10 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): return ", ".join(self.db.get_usable_table_names()) +class _QuerySQLCheckerToolInput(BaseModel): + query: str = Field(..., description="A detailed and SQL query to be checked.") + + class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): """Use an LLM to check if a query is correct. Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" @@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query! """ + args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput @root_validator(pre=True) def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]: