diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py index 42953fb0a2e..73915732cf1 100644 --- a/libs/community/langchain_community/agent_toolkits/sql/base.py +++ b/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -2,7 +2,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union from langchain_core.messages import AIMessage, SystemMessage from langchain_core.prompts import BasePromptTemplate, PromptTemplate @@ -40,6 +40,7 @@ def create_sql_agent( prefix: Optional[str] = None, suffix: Optional[str] = None, format_instructions: Optional[str] = None, + input_variables: Optional[List[str]] = None, top_k: int = 10, max_iterations: Optional[int] = 15, max_execution_time: Optional[float] = None, @@ -69,6 +70,7 @@ def create_sql_agent( format_instructions: Formatting instructions to pass to ZeroShotAgent.create_prompt() when 'agent_type' is "zero-shot-react-description". Otherwise ignored. + input_variables: DEPRECATED. top_k: Number of rows to query for by default. max_iterations: Passed to AgentExecutor init. max_execution_time: Passed to AgentExecutor init. @@ -119,6 +121,9 @@ def create_sql_agent( raise ValueError( "Must provide exactly one of 'toolkit' or 'db'. Received both." ) + if input_variables: + kwargs = kwargs or {} + kwargs["input_variables"] = input_variables if kwargs: warnings.warn( f"Received additional kwargs {kwargs} which are no longer supported."