mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 11:09:07 +00:00
improve sql prompt (#4611)
Co-authored-by: Taqi Jaffri <tjaffri@docugami.com> Co-authored-by: Taqi Jaffri <tjaffri@gmail.com>
This commit is contained in:
parent
01531cb16d
commit
7d425cbf38
File diff suppressed because one or more lines are too long
@ -12,7 +12,11 @@ from langchain.chains.base import Chain
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
|
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||||
|
|
||||||
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
|
|
||||||
|
|
||||||
class SQLDatabaseChain(Chain):
|
class SQLDatabaseChain(Chain):
|
||||||
@ -41,6 +45,11 @@ class SQLDatabaseChain(Chain):
|
|||||||
"""Whether or not to return the intermediate steps along with the final answer."""
|
"""Whether or not to return the intermediate steps along with the final answer."""
|
||||||
return_direct: bool = False
|
return_direct: bool = False
|
||||||
"""Whether or not to return the result of querying the SQL table directly."""
|
"""Whether or not to return the result of querying the SQL table directly."""
|
||||||
|
use_query_checker: bool = False
|
||||||
|
"""Whether or not the query checker tool should be used to attempt
|
||||||
|
to fix the initial SQL from the LLM."""
|
||||||
|
query_checker_prompt: Optional[BasePromptTemplate] = None
|
||||||
|
"""The prompt template that should be used by the query checker"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -81,7 +90,7 @@ class SQLDatabaseChain(Chain):
|
|||||||
if not self.return_intermediate_steps:
|
if not self.return_intermediate_steps:
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
else:
|
else:
|
||||||
return [self.output_key, "intermediate_steps"]
|
return [self.output_key, INTERMEDIATE_STEPS_KEY]
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@ -96,36 +105,80 @@ class SQLDatabaseChain(Chain):
|
|||||||
table_info = self.database.get_table_info(table_names=table_names_to_use)
|
table_info = self.database.get_table_info(table_names=table_names_to_use)
|
||||||
llm_inputs = {
|
llm_inputs = {
|
||||||
"input": input_text,
|
"input": input_text,
|
||||||
"top_k": self.top_k,
|
"top_k": str(self.top_k),
|
||||||
"dialect": self.database.dialect,
|
"dialect": self.database.dialect,
|
||||||
"table_info": table_info,
|
"table_info": table_info,
|
||||||
"stop": ["\nSQLResult:"],
|
"stop": ["\nSQLResult:"],
|
||||||
}
|
}
|
||||||
intermediate_steps = []
|
intermediate_steps: List = []
|
||||||
sql_cmd = self.llm_chain.predict(
|
try:
|
||||||
callbacks=_run_manager.get_child(), **llm_inputs
|
intermediate_steps.append(llm_inputs) # input: sql generation
|
||||||
)
|
sql_cmd = self.llm_chain.predict(
|
||||||
intermediate_steps.append(sql_cmd)
|
callbacks=_run_manager.get_child(),
|
||||||
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
|
**llm_inputs,
|
||||||
result = self.database.run(sql_cmd)
|
).strip()
|
||||||
intermediate_steps.append(result)
|
if not self.use_query_checker:
|
||||||
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
|
||||||
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
|
intermediate_steps.append(
|
||||||
# If return direct, we just set the final result equal to the sql query
|
sql_cmd
|
||||||
if self.return_direct:
|
) # output: sql generation (no checker)
|
||||||
final_result = result
|
intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec
|
||||||
else:
|
result = self.database.run(sql_cmd)
|
||||||
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
|
intermediate_steps.append(str(result)) # output: sql exec
|
||||||
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
|
else:
|
||||||
llm_inputs["input"] = input_text
|
query_checker_prompt = self.query_checker_prompt or PromptTemplate(
|
||||||
final_result = self.llm_chain.predict(
|
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
||||||
callbacks=_run_manager.get_child(), **llm_inputs
|
)
|
||||||
)
|
query_checker_chain = LLMChain(
|
||||||
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
|
llm=self.llm, prompt=query_checker_prompt
|
||||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
)
|
||||||
if self.return_intermediate_steps:
|
query_checker_inputs = {
|
||||||
chain_result["intermediate_steps"] = intermediate_steps
|
"query": sql_cmd,
|
||||||
return chain_result
|
"dialect": self.database.dialect,
|
||||||
|
}
|
||||||
|
checked_sql_command: str = query_checker_chain.predict(
|
||||||
|
callbacks=_run_manager.get_child(), **query_checker_inputs
|
||||||
|
).strip()
|
||||||
|
intermediate_steps.append(
|
||||||
|
checked_sql_command
|
||||||
|
) # output: sql generation (checker)
|
||||||
|
_run_manager.on_text(
|
||||||
|
checked_sql_command, color="green", verbose=self.verbose
|
||||||
|
)
|
||||||
|
intermediate_steps.append(
|
||||||
|
{"sql_cmd": checked_sql_command}
|
||||||
|
) # input: sql exec
|
||||||
|
result = self.database.run(checked_sql_command)
|
||||||
|
intermediate_steps.append(str(result)) # output: sql exec
|
||||||
|
sql_cmd = checked_sql_command
|
||||||
|
|
||||||
|
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
||||||
|
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
|
||||||
|
# If return direct, we just set the final result equal to
|
||||||
|
# the result of the sql query result, otherwise try to get a human readable
|
||||||
|
# final answer
|
||||||
|
if self.return_direct:
|
||||||
|
final_result = result
|
||||||
|
else:
|
||||||
|
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
|
||||||
|
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
|
||||||
|
llm_inputs["input"] = input_text
|
||||||
|
intermediate_steps.append(llm_inputs) # input: final answer
|
||||||
|
final_result = self.llm_chain.predict(
|
||||||
|
callbacks=_run_manager.get_child(),
|
||||||
|
**llm_inputs,
|
||||||
|
).strip()
|
||||||
|
intermediate_steps.append(final_result) # output: final answer
|
||||||
|
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
|
||||||
|
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||||
|
if self.return_intermediate_steps:
|
||||||
|
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
||||||
|
return chain_result
|
||||||
|
except Exception as exc:
|
||||||
|
# Append intermediate steps to exception, to aid in logging and later
|
||||||
|
# improvement of few shot prompt seeds
|
||||||
|
exc.intermediate_steps = intermediate_steps # type: ignore
|
||||||
|
raise exc
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
@ -195,7 +248,7 @@ class SQLDatabaseSequentialChain(Chain):
|
|||||||
if not self.return_intermediate_steps:
|
if not self.return_intermediate_steps:
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
else:
|
else:
|
||||||
return [self.output_key, "intermediate_steps"]
|
return [self.output_key, INTERMEDIATE_STEPS_KEY]
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@ -209,9 +262,13 @@ class SQLDatabaseSequentialChain(Chain):
|
|||||||
"query": inputs[self.input_key],
|
"query": inputs[self.input_key],
|
||||||
"table_names": table_names,
|
"table_names": table_names,
|
||||||
}
|
}
|
||||||
table_names_to_use = self.decider_chain.predict_and_parse(
|
_lowercased_table_names = [name.lower() for name in _table_names]
|
||||||
callbacks=_run_manager.get_child(), **llm_inputs
|
table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
|
||||||
)
|
table_names_to_use = [
|
||||||
|
name
|
||||||
|
for name in table_names_from_chain
|
||||||
|
if name.lower() in _lowercased_table_names
|
||||||
|
]
|
||||||
_run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
|
_run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
|
||||||
_run_manager.on_text(
|
_run_manager.on_text(
|
||||||
str(table_names_to_use), color="yellow", verbose=self.verbose
|
str(table_names_to_use), color="yellow", verbose=self.verbose
|
||||||
|
@ -3,6 +3,11 @@ from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
|||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_SUFFIX = """Only use the following tables:
|
||||||
|
{table_info}
|
||||||
|
|
||||||
|
Question: {input}"""
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
|
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
|
||||||
|
|
||||||
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
|
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
|
||||||
@ -16,17 +21,14 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the tables listed below.
|
"""
|
||||||
|
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
PROMPT = PromptTemplate(
|
PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "dialect", "top_k"],
|
input_variables=["input", "table_info", "dialect", "top_k"],
|
||||||
template=_DEFAULT_TEMPLATE,
|
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question.
|
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question.
|
||||||
|
|
||||||
Question: {query}
|
Question: {query}
|
||||||
@ -53,14 +55,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
DUCKDB_PROMPT = PromptTemplate(
|
DUCKDB_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_duckdb_prompt,
|
template=_duckdb_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
_googlesql_prompt = """You are a GoogleSQL expert. Given an input question, first create a syntactically correct GoogleSQL query to run, then look at the results of the query and return the answer to the input question.
|
_googlesql_prompt = """You are a GoogleSQL expert. Given an input question, first create a syntactically correct GoogleSQL query to run, then look at the results of the query and return the answer to the input question.
|
||||||
@ -76,14 +75,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
GOOGLESQL_PROMPT = PromptTemplate(
|
GOOGLESQL_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_googlesql_prompt,
|
template=_googlesql_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -100,13 +96,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
MSSQL_PROMPT = PromptTemplate(
|
MSSQL_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"], template=_mssql_prompt
|
input_variables=["input", "table_info", "top_k"],
|
||||||
|
template=_mssql_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -123,14 +117,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
MYSQL_PROMPT = PromptTemplate(
|
MYSQL_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_mysql_prompt,
|
template=_mysql_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -147,14 +138,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
MARIADB_PROMPT = PromptTemplate(
|
MARIADB_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_mariadb_prompt,
|
template=_mariadb_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -171,14 +159,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
ORACLE_PROMPT = PromptTemplate(
|
ORACLE_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_oracle_prompt,
|
template=_oracle_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -195,13 +180,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
POSTGRES_PROMPT = PromptTemplate(
|
POSTGRES_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"], template=_postgres_prompt
|
input_variables=["input", "table_info", "top_k"],
|
||||||
|
template=_postgres_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -218,14 +201,11 @@ SQLQuery: SQL Query to run
|
|||||||
SQLResult: Result of the SQLQuery
|
SQLResult: Result of the SQLQuery
|
||||||
Answer: Final answer here
|
Answer: Final answer here
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
SQLITE_PROMPT = PromptTemplate(
|
SQLITE_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_sqlite_prompt,
|
template=_sqlite_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
_clickhouse_prompt = """You are a ClickHouse expert. Given an input question, first create a syntactically correct Clic query to run, then look at the results of the query and return the answer to the input question.
|
_clickhouse_prompt = """You are a ClickHouse expert. Given an input question, first create a syntactically correct Clic query to run, then look at the results of the query and return the answer to the input question.
|
||||||
@ -241,14 +221,11 @@ SQLQuery: "SQL Query to run"
|
|||||||
SQLResult: "Result of the SQLQuery"
|
SQLResult: "Result of the SQLQuery"
|
||||||
Answer: "Final answer here"
|
Answer: "Final answer here"
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
CLICKHOUSE_PROMPT = PromptTemplate(
|
CLICKHOUSE_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_clickhouse_prompt,
|
template=_clickhouse_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
_prestodb_prompt = """You are a PrestoDB expert. Given an input question, first create a syntactically correct PrestoDB query to run, then look at the results of the query and return the answer to the input question.
|
_prestodb_prompt = """You are a PrestoDB expert. Given an input question, first create a syntactically correct PrestoDB query to run, then look at the results of the query and return the answer to the input question.
|
||||||
@ -264,14 +241,11 @@ SQLQuery: "SQL Query to run"
|
|||||||
SQLResult: "Result of the SQLQuery"
|
SQLResult: "Result of the SQLQuery"
|
||||||
Answer: "Final answer here"
|
Answer: "Final answer here"
|
||||||
|
|
||||||
Only use the following tables:
|
"""
|
||||||
{table_info}
|
|
||||||
|
|
||||||
Question: {input}"""
|
|
||||||
|
|
||||||
PRESTODB_PROMPT = PromptTemplate(
|
PRESTODB_PROMPT = PromptTemplate(
|
||||||
input_variables=["input", "table_info", "top_k"],
|
input_variables=["input", "table_info", "top_k"],
|
||||||
template=_prestodb_prompt,
|
template=_prestodb_prompt + PROMPT_SUFFIX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user