Stop checking the query before running it; automatically add the table names and the retriever results to the prompt

This commit is contained in:
Francisco Ingham
2023-09-08 18:07:21 -03:00
parent 5d8a689d5e
commit deb89e5721
3 changed files with 99 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
"""SQL agent."""
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
@@ -11,19 +11,90 @@ from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import (
OpenAIFunctionsAgent,
_format_intermediate_steps,
_parse_ai_message,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, SystemMessage
from langchain.schema import AgentAction, AgentFinish, BasePromptTemplate
from langchain.tools import BaseTool
class SQLOpenAIFunctionsAgent(OpenAIFunctionsAgent):
    """OpenAI-functions agent that adds retriever output to every prompt.

    Before each call to the language model, every callable registered in
    ``automatic_retrievers`` is invoked with the user's ``input`` and the
    stringified results are supplied to the prompt as ``retrieved_results``.
    """

    llm: BaseLanguageModel
    tools: Sequence[BaseTool]
    prompt: BasePromptTemplate
    # (intro_message, retriever) pairs. ``typing.Tuple`` is used instead of the
    # builtin ``tuple[...]`` subscript because the annotation is evaluated when
    # the class is created and the builtin form is unavailable before
    # Python 3.9. Defaults to no retrievers.
    automatic_retrievers: Sequence[Tuple[str, Callable]] = ()

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        with_functions: bool = True,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with
                observations.
            callbacks: Callbacks forwarded to the language model call.
            with_functions: When true, ``self.functions`` is passed to the
                model; when false the model is called without functions.
            **kwargs: User inputs. Must contain every prompt input variable
                other than ``agent_scratchpad`` and ``retrieved_results``.

        Returns:
            Action specifying what tool to use, or the final answer.

        Raises:
            KeyError: If a required prompt variable is missing from
                ``kwargs``, or if retrievers are configured and the prompt
                has no ``input`` variable.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)

        # These two variables are produced here, not supplied by the caller.
        generated = ("agent_scratchpad", "retrieved_results")
        selected_inputs = {
            name: kwargs[name]
            for name in self.prompt.input_variables
            if name not in generated
        }

        # One "<intro>\n<result>\n\n" section per configured retriever; the
        # string is empty when no retrievers are configured.
        retrieved_results = "".join(
            intro_message + "\n" + str(retriever(selected_inputs["input"])) + "\n\n"
            for intro_message, retriever in self.automatic_retrievers
        )

        full_inputs = dict(
            **selected_inputs,
            agent_scratchpad=agent_scratchpad,
            retrieved_results=retrieved_results,
        )
        messages = self.prompt.format_prompt(**full_inputs).to_messages()

        # Only advertise the function schemas when the caller asked for them.
        predict_kwargs: Dict[str, Any] = {"callbacks": callbacks}
        if with_functions:
            predict_kwargs["functions"] = self.functions
        predicted_message = self.llm.predict_messages(messages, **predict_kwargs)

        return _parse_ai_message(predicted_message)
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
@@ -40,6 +111,7 @@ def create_sql_agent(
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (),
automatic_retrievers: Sequence[tuple[str, Callable]] = (),
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
@@ -47,6 +119,11 @@ def create_sql_agent(
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent
prefix += "\n These are the tables in the database: "
prefix += ", ".join(toolkit.db.get_usable_table_names())
print(prefix)
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(
tools,
@@ -55,6 +132,7 @@ def create_sql_agent(
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
@@ -71,13 +149,21 @@ def create_sql_agent(
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
if automatic_retrievers:
messages.insert(
1, AIMessagePromptTemplate.from_template("{retrieved_results}")
)
input_variables.append("retrieved_results")
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent(
agent = SQLOpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
automatic_retrievers=automatic_retrievers,
**kwargs,
)
else:

View File

@@ -20,4 +20,6 @@ Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
# SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
SQL_FUNCTIONS_SUFFIX = ""

View File

@@ -4,7 +4,6 @@ from typing import List
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.pydantic_v1 import Field
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
@@ -12,6 +11,9 @@ from langchain.tools.sql_database.tool import (
QuerySQLDataBaseTool,
)
from langchain.utilities.sql_database import SQLDatabase
from pydantic import Field
from langchain.tools import BaseTool
class SQLDatabaseToolkit(BaseToolkit):
@@ -34,10 +36,8 @@ class SQLDatabaseToolkit(BaseToolkit):
"""Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"Input to this tool is a comma-separated list of tables (with spaces), output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: 'table1, table2, table3'"
)
info_sql_database_tool = InfoSQLDatabaseTool(
@@ -55,9 +55,8 @@ class SQLDatabaseToolkit(BaseToolkit):
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing "
"it. Always use this tool before executing a query with "
f"{query_sql_database_tool.name}!"
"Use this tool to double check if your query is correct if"
"recovering from an error"
)
query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
@@ -65,6 +64,6 @@ class SQLDatabaseToolkit(BaseToolkit):
return [
query_sql_database_tool,
info_sql_database_tool,
list_sql_database_tool,
# list_sql_database_tool,
query_sql_checker_tool,
]