experimental[patch]: fix zero-shot pandas agent (#17442)

This commit is contained in:
Bagatur
2024-02-12 21:58:35 -08:00
committed by GitHub
parent 37e1275f9e
commit c0ce93236a
7 changed files with 135 additions and 28 deletions

View File

@@ -10,7 +10,7 @@ from langchain.agents.agent import (
RunnableAgent,
RunnableMultiActionAgent,
)
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import (
OpenAIFunctionsAgent,
create_openai_functions_agent,
@@ -18,7 +18,11 @@ from langchain.agents.openai_functions_agent.base import (
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import SystemMessage
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
PromptTemplate,
)
from langchain_core.tools import BaseTool
from langchain_core.utils.interactive_env import is_interactive_env
@@ -43,7 +47,6 @@ def _get_multi_prompt(
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
tools: Sequence[BaseTool] = (),
) -> BasePromptTemplate:
if suffix is not None:
suffix_to_use = suffix
@@ -53,11 +56,8 @@ def _get_multi_prompt(
suffix_to_use = SUFFIX_NO_DF
prefix = prefix if prefix is not None else MULTI_DF_PREFIX
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix_to_use,
)
template = "\n\n".join([prefix, "{tools}", FORMAT_INSTRUCTIONS, suffix_to_use])
prompt = PromptTemplate.from_template(template)
partial_prompt = prompt.partial()
if "dfs_head" in partial_prompt.input_variables:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
@@ -74,7 +74,6 @@ def _get_single_prompt(
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
tools: Sequence[BaseTool] = (),
) -> BasePromptTemplate:
if suffix is not None:
suffix_to_use = suffix
@@ -84,11 +83,8 @@ def _get_single_prompt(
suffix_to_use = SUFFIX_NO_DF
prefix = prefix if prefix is not None else PREFIX
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix_to_use,
)
template = "\n\n".join([prefix, "{tools}", FORMAT_INSTRUCTIONS, suffix_to_use])
prompt = PromptTemplate.from_template(template)
partial_prompt = prompt.partial()
if "df_head" in partial_prompt.input_variables:
@@ -257,7 +253,6 @@ def create_pandas_dataframe_agent(
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
tools=tools,
)
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] = RunnableAgent(
runnable=create_react_agent(llm, tools, prompt), # type: ignore