diff --git a/libs/community/langchain_community/agent_toolkits/sql/base.py b/libs/community/langchain_community/agent_toolkits/sql/base.py index a4d638b32c2..f9b6e400630 100644 --- a/libs/community/langchain_community/agent_toolkits/sql/base.py +++ b/libs/community/langchain_community/agent_toolkits/sql/base.py @@ -150,17 +150,18 @@ def create_sql_agent( prompt = prompt.partial(top_k=str(top_k)) if "dialect" in prompt.input_variables: prompt = prompt.partial(dialect=toolkit.dialect) - db_context = toolkit.get_context() - if "table_info" in prompt.input_variables: - prompt = prompt.partial(table_info=db_context["table_info"]) - tools = [ - tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool) - ] - if "table_names" in prompt.input_variables: - prompt = prompt.partial(table_names=db_context["table_names"]) - tools = [ - tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool) - ] + if any(key in prompt.input_variables for key in ["table_info", "table_names"]): + db_context = toolkit.get_context() + if "table_info" in prompt.input_variables: + prompt = prompt.partial(table_info=db_context["table_info"]) + tools = [ + tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool) + ] + if "table_names" in prompt.input_variables: + prompt = prompt.partial(table_names=db_context["table_names"]) + tools = [ + tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool) + ] if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: if prompt is None: