mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community[patch]: add args_schema to sql_database tools for langGraph integration (#18595)
- **Description:** This modification adds pydantic input definition for sql_database tools. This helps for function calling capability in LangGraph. Since actions nodes will usually check for the args_schema attribute on tools, This update should make these tools compatible with it (only implemented on the InfoSQLDatabaseTool) - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** juanfe8881
This commit is contained in:
parent
c922ea36cb
commit
f5b9aedc48
@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
If the query is not correct, an error message will be returned.
|
||||
If an error is returned, rewrite the query, check the query, and try again.
|
||||
"""
|
||||
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -77,11 +78,16 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
)
|
||||
|
||||
|
||||
class _ListSQLDataBaseToolInput(BaseModel):
|
||||
tool_input: str = Field(..., description="An empty string")
|
||||
|
||||
|
||||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting tables names."""
|
||||
|
||||
name: str = "sql_db_list_tables"
|
||||
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -92,6 +98,10 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
return ", ".join(self.db.get_usable_table_names())
|
||||
|
||||
|
||||
class _QuerySQLCheckerToolInput(BaseModel):
|
||||
query: str = Field(..., description="A detailed and SQL query to be checked.")
|
||||
|
||||
|
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
Use this tool to double check if your query is correct before executing it.
|
||||
Always use this tool before executing a query with sql_db_query!
|
||||
"""
|
||||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
Loading…
Reference in New Issue
Block a user