mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
not checking query before running, adding table names automatically, adding retriever results automatically
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""SQL agent."""
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_toolkits.sql.prompt import (
|
||||
@@ -11,19 +11,90 @@ from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import (
|
||||
OpenAIFunctionsAgent,
|
||||
_format_intermediate_steps,
|
||||
_parse_ai_message,
|
||||
)
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, SystemMessage
|
||||
|
||||
from langchain.schema import AgentAction, AgentFinish, BasePromptTemplate
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class SQLOpenAIFunctionsAgent(OpenAIFunctionsAgent):
|
||||
llm: BaseLanguageModel
|
||||
tools: Sequence[BaseTool]
|
||||
prompt: BasePromptTemplate
|
||||
automatic_retrievers: Sequence[tuple[str, Callable]]
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
with_functions: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k]
|
||||
for k in self.prompt.input_variables
|
||||
if k != "agent_scratchpad" and k != "retrieved_results"
|
||||
}
|
||||
|
||||
# Add the automatic retrievers' results to the AI message
|
||||
retrieved_results = ""
|
||||
for intro_message, retriever in self.automatic_retrievers:
|
||||
result = retriever(selected_inputs["input"])
|
||||
retrieved_results += intro_message + "\n" + str(result)
|
||||
retrieved_results += "\n\n"
|
||||
|
||||
full_inputs = dict(
|
||||
**selected_inputs,
|
||||
agent_scratchpad=agent_scratchpad,
|
||||
retrieved_results=retrieved_results,
|
||||
)
|
||||
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
|
||||
messages = prompt.to_messages()
|
||||
|
||||
print(messages)
|
||||
|
||||
if with_functions:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages,
|
||||
functions=self.functions,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
else:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
|
||||
def create_sql_agent(
|
||||
llm: BaseLanguageModel,
|
||||
toolkit: SQLDatabaseToolkit,
|
||||
@@ -40,6 +111,7 @@ def create_sql_agent(
|
||||
verbose: bool = False,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
automatic_retrievers: Sequence[tuple[str, Callable]] = (),
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct an SQL agent from an LLM and tools."""
|
||||
@@ -47,6 +119,11 @@ def create_sql_agent(
|
||||
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
|
||||
agent: BaseSingleActionAgent
|
||||
|
||||
prefix += "\n These are the tables in the database: "
|
||||
prefix += ", ".join(toolkit.db.get_usable_table_names())
|
||||
|
||||
print(prefix)
|
||||
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
@@ -55,6 +132,7 @@ def create_sql_agent(
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
@@ -71,13 +149,21 @@ def create_sql_agent(
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
|
||||
if automatic_retrievers:
|
||||
messages.insert(
|
||||
1, AIMessagePromptTemplate.from_template("{retrieved_results}")
|
||||
)
|
||||
input_variables.append("retrieved_results")
|
||||
|
||||
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
agent = OpenAIFunctionsAgent(
|
||||
agent = SQLOpenAIFunctionsAgent(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
automatic_retrievers=automatic_retrievers,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -20,4 +20,6 @@ Question: {input}
|
||||
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
|
||||
{agent_scratchpad}"""
|
||||
|
||||
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
|
||||
# SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
|
||||
|
||||
SQL_FUNCTIONS_SUFFIX = ""
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import List
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
@@ -12,6 +11,9 @@ from langchain.tools.sql_database.tool import (
|
||||
QuerySQLDataBaseTool,
|
||||
)
|
||||
from langchain.utilities.sql_database import SQLDatabase
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class SQLDatabaseToolkit(BaseToolkit):
|
||||
@@ -34,10 +36,8 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
"""Get the tools in the toolkit."""
|
||||
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
|
||||
info_sql_database_tool_description = (
|
||||
"Input to this tool is a comma-separated list of tables, output is the "
|
||||
"Input to this tool is a comma-separated list of tables (with spaces), output is the "
|
||||
"schema and sample rows for those tables. "
|
||||
"Be sure that the tables actually exist by calling "
|
||||
f"{list_sql_database_tool.name} first! "
|
||||
"Example Input: 'table1, table2, table3'"
|
||||
)
|
||||
info_sql_database_tool = InfoSQLDatabaseTool(
|
||||
@@ -55,9 +55,8 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
db=self.db, description=query_sql_database_tool_description
|
||||
)
|
||||
query_sql_checker_tool_description = (
|
||||
"Use this tool to double check if your query is correct before executing "
|
||||
"it. Always use this tool before executing a query with "
|
||||
f"{query_sql_database_tool.name}!"
|
||||
"Use this tool to double check if your query is correct if"
|
||||
"recovering from an error"
|
||||
)
|
||||
query_sql_checker_tool = QuerySQLCheckerTool(
|
||||
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
|
||||
@@ -65,6 +64,6 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
return [
|
||||
query_sql_database_tool,
|
||||
info_sql_database_tool,
|
||||
list_sql_database_tool,
|
||||
# list_sql_database_tool,
|
||||
query_sql_checker_tool,
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user