mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-28 02:29:17 +00:00
improve sql prompt (#4611)
Co-authored-by: Taqi Jaffri <tjaffri@docugami.com> Co-authored-by: Taqi Jaffri <tjaffri@gmail.com>
This commit is contained in:
parent
01531cb16d
commit
7d425cbf38
File diff suppressed because one or more lines are too long
@ -12,7 +12,11 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
class SQLDatabaseChain(Chain):
|
||||
@ -41,6 +45,11 @@ class SQLDatabaseChain(Chain):
|
||||
"""Whether or not to return the intermediate steps along with the final answer."""
|
||||
return_direct: bool = False
|
||||
"""Whether or not to return the result of querying the SQL table directly."""
|
||||
use_query_checker: bool = False
|
||||
"""Whether or not the query checker tool should be used to attempt
|
||||
to fix the initial SQL from the LLM."""
|
||||
query_checker_prompt: Optional[BasePromptTemplate] = None
|
||||
"""The prompt template that should be used by the query checker"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -81,7 +90,7 @@ class SQLDatabaseChain(Chain):
|
||||
if not self.return_intermediate_steps:
|
||||
return [self.output_key]
|
||||
else:
|
||||
return [self.output_key, "intermediate_steps"]
|
||||
return [self.output_key, INTERMEDIATE_STEPS_KEY]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -96,36 +105,80 @@ class SQLDatabaseChain(Chain):
|
||||
table_info = self.database.get_table_info(table_names=table_names_to_use)
|
||||
llm_inputs = {
|
||||
"input": input_text,
|
||||
"top_k": self.top_k,
|
||||
"top_k": str(self.top_k),
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": table_info,
|
||||
"stop": ["\nSQLResult:"],
|
||||
}
|
||||
intermediate_steps = []
|
||||
sql_cmd = self.llm_chain.predict(
|
||||
callbacks=_run_manager.get_child(), **llm_inputs
|
||||
)
|
||||
intermediate_steps.append(sql_cmd)
|
||||
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
|
||||
result = self.database.run(sql_cmd)
|
||||
intermediate_steps.append(result)
|
||||
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
||||
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
|
||||
# If return direct, we just set the final result equal to the sql query
|
||||
if self.return_direct:
|
||||
final_result = result
|
||||
else:
|
||||
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
|
||||
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
|
||||
llm_inputs["input"] = input_text
|
||||
final_result = self.llm_chain.predict(
|
||||
callbacks=_run_manager.get_child(), **llm_inputs
|
||||
)
|
||||
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
|
||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||
if self.return_intermediate_steps:
|
||||
chain_result["intermediate_steps"] = intermediate_steps
|
||||
return chain_result
|
||||
intermediate_steps: List = []
|
||||
try:
|
||||
intermediate_steps.append(llm_inputs) # input: sql generation
|
||||
sql_cmd = self.llm_chain.predict(
|
||||
callbacks=_run_manager.get_child(),
|
||||
**llm_inputs,
|
||||
).strip()
|
||||
if not self.use_query_checker:
|
||||
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
|
||||
intermediate_steps.append(
|
||||
sql_cmd
|
||||
) # output: sql generation (no checker)
|
||||
intermediate_steps.append({"sql_cmd": sql_cmd}) # input: sql exec
|
||||
result = self.database.run(sql_cmd)
|
||||
intermediate_steps.append(str(result)) # output: sql exec
|
||||
else:
|
||||
query_checker_prompt = self.query_checker_prompt or PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
||||
)
|
||||
query_checker_chain = LLMChain(
|
||||
llm=self.llm, prompt=query_checker_prompt
|
||||
)
|
||||
query_checker_inputs = {
|
||||
"query": sql_cmd,
|
||||
"dialect": self.database.dialect,
|
||||
}
|
||||
checked_sql_command: str = query_checker_chain.predict(
|
||||
callbacks=_run_manager.get_child(), **query_checker_inputs
|
||||
).strip()
|
||||
intermediate_steps.append(
|
||||
checked_sql_command
|
||||
) # output: sql generation (checker)
|
||||
_run_manager.on_text(
|
||||
checked_sql_command, color="green", verbose=self.verbose
|
||||
)
|
||||
intermediate_steps.append(
|
||||
{"sql_cmd": checked_sql_command}
|
||||
) # input: sql exec
|
||||
result = self.database.run(checked_sql_command)
|
||||
intermediate_steps.append(str(result)) # output: sql exec
|
||||
sql_cmd = checked_sql_command
|
||||
|
||||
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
|
||||
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
|
||||
# If return direct, we just set the final result equal to
|
||||
# the result of the sql query result, otherwise try to get a human readable
|
||||
# final answer
|
||||
if self.return_direct:
|
||||
final_result = result
|
||||
else:
|
||||
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
|
||||
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
|
||||
llm_inputs["input"] = input_text
|
||||
intermediate_steps.append(llm_inputs) # input: final answer
|
||||
final_result = self.llm_chain.predict(
|
||||
callbacks=_run_manager.get_child(),
|
||||
**llm_inputs,
|
||||
).strip()
|
||||
intermediate_steps.append(final_result) # output: final answer
|
||||
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
|
||||
chain_result: Dict[str, Any] = {self.output_key: final_result}
|
||||
if self.return_intermediate_steps:
|
||||
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
|
||||
return chain_result
|
||||
except Exception as exc:
|
||||
# Append intermediate steps to exception, to aid in logging and later
|
||||
# improvement of few shot prompt seeds
|
||||
exc.intermediate_steps = intermediate_steps # type: ignore
|
||||
raise exc
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
@ -195,7 +248,7 @@ class SQLDatabaseSequentialChain(Chain):
|
||||
if not self.return_intermediate_steps:
|
||||
return [self.output_key]
|
||||
else:
|
||||
return [self.output_key, "intermediate_steps"]
|
||||
return [self.output_key, INTERMEDIATE_STEPS_KEY]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -209,9 +262,13 @@ class SQLDatabaseSequentialChain(Chain):
|
||||
"query": inputs[self.input_key],
|
||||
"table_names": table_names,
|
||||
}
|
||||
table_names_to_use = self.decider_chain.predict_and_parse(
|
||||
callbacks=_run_manager.get_child(), **llm_inputs
|
||||
)
|
||||
_lowercased_table_names = [name.lower() for name in _table_names]
|
||||
table_names_from_chain = self.decider_chain.predict_and_parse(**llm_inputs)
|
||||
table_names_to_use = [
|
||||
name
|
||||
for name in table_names_from_chain
|
||||
if name.lower() in _lowercased_table_names
|
||||
]
|
||||
_run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
|
||||
_run_manager.on_text(
|
||||
str(table_names_to_use), color="yellow", verbose=self.verbose
|
||||
|
@ -3,6 +3,11 @@ from langchain.output_parsers.list import CommaSeparatedListOutputParser
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
PROMPT_SUFFIX = """Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
|
||||
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
|
||||
|
||||
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
|
||||
@ -16,17 +21,14 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the tables listed below.
|
||||
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "dialect", "top_k"],
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
template=_DEFAULT_TEMPLATE + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be necessary to answer this question.
|
||||
|
||||
Question: {query}
|
||||
@ -53,14 +55,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
DUCKDB_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_duckdb_prompt,
|
||||
template=_duckdb_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
_googlesql_prompt = """You are a GoogleSQL expert. Given an input question, first create a syntactically correct GoogleSQL query to run, then look at the results of the query and return the answer to the input question.
|
||||
@ -76,14 +75,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
GOOGLESQL_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_googlesql_prompt,
|
||||
template=_googlesql_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@ -100,13 +96,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
MSSQL_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"], template=_mssql_prompt
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_mssql_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@ -123,14 +117,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
MYSQL_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_mysql_prompt,
|
||||
template=_mysql_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@ -147,14 +138,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
MARIADB_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_mariadb_prompt,
|
||||
template=_mariadb_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@ -171,14 +159,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
ORACLE_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_oracle_prompt,
|
||||
template=_oracle_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@ -195,13 +180,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
POSTGRES_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"], template=_postgres_prompt
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_postgres_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@ -218,14 +201,11 @@ SQLQuery: SQL Query to run
|
||||
SQLResult: Result of the SQLQuery
|
||||
Answer: Final answer here
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
SQLITE_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_sqlite_prompt,
|
||||
template=_sqlite_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
_clickhouse_prompt = """You are a ClickHouse expert. Given an input question, first create a syntactically correct Clic query to run, then look at the results of the query and return the answer to the input question.
|
||||
@ -241,14 +221,11 @@ SQLQuery: "SQL Query to run"
|
||||
SQLResult: "Result of the SQLQuery"
|
||||
Answer: "Final answer here"
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
CLICKHOUSE_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_clickhouse_prompt,
|
||||
template=_clickhouse_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
_prestodb_prompt = """You are a PrestoDB expert. Given an input question, first create a syntactically correct PrestoDB query to run, then look at the results of the query and return the answer to the input question.
|
||||
@ -264,14 +241,11 @@ SQLQuery: "SQL Query to run"
|
||||
SQLResult: "Result of the SQLQuery"
|
||||
Answer: "Final answer here"
|
||||
|
||||
Only use the following tables:
|
||||
{table_info}
|
||||
|
||||
Question: {input}"""
|
||||
"""
|
||||
|
||||
PRESTODB_PROMPT = PromptTemplate(
|
||||
input_variables=["input", "table_info", "top_k"],
|
||||
template=_prestodb_prompt,
|
||||
template=_prestodb_prompt + PROMPT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user