Compare commits

...

2 Commits

Author SHA1 Message Date
Francisco Ingham
69e29a3fed modifying suffix to account for changes 2023-09-08 18:15:45 -03:00
Francisco Ingham
deb89e5721 not checking query before running, adding table names automatically, adding retriever results automatically 2023-09-08 18:07:21 -03:00
3 changed files with 97 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
"""SQL agent.""" """SQL agent."""
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import ( from langchain.agents.agent_toolkits.sql.prompt import (
@@ -11,19 +11,90 @@ from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import (
OpenAIFunctionsAgent,
_format_intermediate_steps,
_parse_ai_message,
)
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ( from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate, ChatPromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
MessagesPlaceholder, MessagesPlaceholder,
) )
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, SystemMessage from langchain.schema.messages import AIMessage, SystemMessage
from langchain.schema import AgentAction, AgentFinish, BasePromptTemplate
from langchain.tools import BaseTool from langchain.tools import BaseTool
class SQLOpenAIFunctionsAgent(OpenAIFunctionsAgent):
    """OpenAI-functions agent that adds automatic retriever output to the prompt.

    On every planning step each callable in ``automatic_retrievers`` is invoked
    with the user's ``input`` value, and the combined text is supplied to the
    prompt as the ``retrieved_results`` variable.
    """

    llm: BaseLanguageModel
    tools: Sequence[BaseTool]
    prompt: BasePromptTemplate
    # Pairs of (intro message, callable). The callable receives the user's
    # ``input`` value and its result is rendered with ``str()``.
    # ``typing.Tuple`` is used instead of the builtin ``tuple[...]`` so the
    # annotation can be evaluated on Python versions before 3.9; the empty
    # default keeps construction paths that do not know this field working.
    automatic_retrievers: Sequence[Tuple[str, Callable]] = ()

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        with_functions: bool = True,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with
                observations.
            callbacks: Callbacks forwarded to the LLM call.
            with_functions: When True, ``self.functions`` is passed to the LLM
                call; when False, the LLM is called without functions.
            **kwargs: User inputs. Must contain every prompt input variable
                other than ``agent_scratchpad`` and ``retrieved_results``.

        Returns:
            The result of ``_parse_ai_message`` applied to the LLM's reply
            (an ``AgentAction`` or ``AgentFinish`` per the annotation).

        Raises:
            KeyError: If a required prompt input variable is missing from
                ``kwargs``, or if retrievers are configured but the prompt has
                no ``input`` variable.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)

        # These two variables are computed here rather than taken from the
        # caller, so they are excluded from the pass-through inputs.
        computed_variables = ("agent_scratchpad", "retrieved_results")
        selected_inputs = {
            name: kwargs[name]
            for name in self.prompt.input_variables
            if name not in computed_variables
        }

        # One section per retriever: intro line, stringified result, blank line.
        # ``selected_inputs["input"]`` is looked up inside the generator so it
        # is only required when at least one retriever is configured.
        retrieved_results = "".join(
            intro_message + "\n" + str(retriever(selected_inputs["input"])) + "\n\n"
            for intro_message, retriever in self.automatic_retrievers
        )

        full_inputs = dict(
            **selected_inputs,
            agent_scratchpad=agent_scratchpad,
            retrieved_results=retrieved_results,
        )
        messages = self.prompt.format_prompt(**full_inputs).to_messages()

        # NOTE: the prompt is intentionally not printed here; it may contain
        # user input and retrieved data, and library code should not write to
        # stdout on every step.
        if with_functions:
            predicted_message = self.llm.predict_messages(
                messages,
                functions=self.functions,
                callbacks=callbacks,
            )
        else:
            predicted_message = self.llm.predict_messages(
                messages,
                callbacks=callbacks,
            )
        return _parse_ai_message(predicted_message)
def create_sql_agent( def create_sql_agent(
llm: BaseLanguageModel, llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit, toolkit: SQLDatabaseToolkit,
@@ -40,6 +111,7 @@ def create_sql_agent(
verbose: bool = False, verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None, agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (), extra_tools: Sequence[BaseTool] = (),
automatic_retrievers: Sequence[tuple[str, Callable]] = (),
**kwargs: Dict[str, Any], **kwargs: Dict[str, Any],
) -> AgentExecutor: ) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools.""" """Construct an SQL agent from an LLM and tools."""
@@ -47,6 +119,11 @@ def create_sql_agent(
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent agent: BaseSingleActionAgent
prefix += "\n These are the tables in the database: "
prefix += ", ".join(toolkit.db.get_usable_table_names())
print(prefix)
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt( prompt = ZeroShotAgent.create_prompt(
tools, tools,
@@ -55,6 +132,7 @@ def create_sql_agent(
format_instructions=format_instructions, format_instructions=format_instructions,
input_variables=input_variables, input_variables=input_variables,
) )
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
@@ -71,13 +149,21 @@ def create_sql_agent(
MessagesPlaceholder(variable_name="agent_scratchpad"), MessagesPlaceholder(variable_name="agent_scratchpad"),
] ]
input_variables = ["input", "agent_scratchpad"] input_variables = ["input", "agent_scratchpad"]
if automatic_retrievers:
messages.insert(
1, AIMessagePromptTemplate.from_template("{retrieved_results}")
)
input_variables.append("retrieved_results")
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages) _prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
agent = OpenAIFunctionsAgent( agent = SQLOpenAIFunctionsAgent(
llm=llm, llm=llm,
prompt=_prompt, prompt=_prompt,
tools=tools, tools=tools,
callback_manager=callback_manager, callback_manager=callback_manager,
automatic_retrievers=automatic_retrievers,
**kwargs, **kwargs,
) )
else: else:

View File

@@ -20,4 +20,4 @@ Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}""" {agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.""" SQL_FUNCTIONS_SUFFIX = """I should query the schema of the most relevant tables."""

View File

@@ -4,7 +4,6 @@ from typing import List
from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.sql_database.tool import ( from langchain.tools.sql_database.tool import (
InfoSQLDatabaseTool, InfoSQLDatabaseTool,
ListSQLDatabaseTool, ListSQLDatabaseTool,
@@ -12,6 +11,9 @@ from langchain.tools.sql_database.tool import (
QuerySQLDataBaseTool, QuerySQLDataBaseTool,
) )
from langchain.utilities.sql_database import SQLDatabase from langchain.utilities.sql_database import SQLDatabase
from pydantic import Field
from langchain.tools import BaseTool
class SQLDatabaseToolkit(BaseToolkit): class SQLDatabaseToolkit(BaseToolkit):
@@ -34,10 +36,8 @@ class SQLDatabaseToolkit(BaseToolkit):
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db) list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = ( info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the " "Input to this tool is a comma-separated list of tables (with spaces), output is the "
"schema and sample rows for those tables. " "schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: 'table1, table2, table3'" "Example Input: 'table1, table2, table3'"
) )
info_sql_database_tool = InfoSQLDatabaseTool( info_sql_database_tool = InfoSQLDatabaseTool(
@@ -55,9 +55,8 @@ class SQLDatabaseToolkit(BaseToolkit):
db=self.db, description=query_sql_database_tool_description db=self.db, description=query_sql_database_tool_description
) )
query_sql_checker_tool_description = ( query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing " "Use this tool to double check if your query is correct if"
"it. Always use this tool before executing a query with " "recovering from an error"
f"{query_sql_database_tool.name}!"
) )
query_sql_checker_tool = QuerySQLCheckerTool( query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description db=self.db, llm=self.llm, description=query_sql_checker_tool_description
@@ -65,6 +64,6 @@ class SQLDatabaseToolkit(BaseToolkit):
return [ return [
query_sql_database_tool, query_sql_database_tool,
info_sql_database_tool, info_sql_database_tool,
list_sql_database_tool, # list_sql_database_tool,
query_sql_checker_tool, query_sql_checker_tool,
] ]