diff --git a/libs/langchain/langchain/chains/sql_database/query.py b/libs/langchain/langchain/chains/sql_database/query.py index bf5feef483b..d424adb4c5d 100644 --- a/libs/langchain/langchain/chains/sql_database/query.py +++ b/libs/langchain/langchain/chains/sql_database/query.py @@ -116,7 +116,9 @@ def create_sql_query_chain( prompt_to_use = SQL_PROMPTS[db.dialect] else: prompt_to_use = PROMPT - if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables): + if {"input", "top_k", "table_info"}.difference( + prompt_to_use.input_variables + list(prompt_to_use.partial_variables) + ): raise ValueError( f"Prompt must have input variables: 'input', 'top_k', " f"'table_info'. Received prompt with input variables: "