mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 21:11:43 +00:00
experimental[patch]: fix zero-shot pandas agent (#17442)
This commit is contained in:
@@ -10,7 +10,7 @@ from langchain.agents.agent import (
|
||||
RunnableAgent,
|
||||
RunnableMultiActionAgent,
|
||||
)
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.agents.openai_functions_agent.base import (
|
||||
OpenAIFunctionsAgent,
|
||||
create_openai_functions_agent,
|
||||
@@ -18,7 +18,11 @@ from langchain.agents.openai_functions_agent.base import (
|
||||
from langchain_core.callbacks import BaseCallbackManager
|
||||
from langchain_core.language_models import LanguageModelLike
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
||||
from langchain_core.prompts import (
|
||||
BasePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
PromptTemplate,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
|
||||
@@ -43,7 +47,6 @@ def _get_multi_prompt(
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
tools: Sequence[BaseTool] = (),
|
||||
) -> BasePromptTemplate:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
@@ -53,11 +56,8 @@ def _get_multi_prompt(
|
||||
suffix_to_use = SUFFIX_NO_DF
|
||||
prefix = prefix if prefix is not None else MULTI_DF_PREFIX
|
||||
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix_to_use,
|
||||
)
|
||||
template = "\n\n".join([prefix, "{tools}", FORMAT_INSTRUCTIONS, suffix_to_use])
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
partial_prompt = prompt.partial()
|
||||
if "dfs_head" in partial_prompt.input_variables:
|
||||
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
|
||||
@@ -74,7 +74,6 @@ def _get_single_prompt(
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
tools: Sequence[BaseTool] = (),
|
||||
) -> BasePromptTemplate:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
@@ -84,11 +83,8 @@ def _get_single_prompt(
|
||||
suffix_to_use = SUFFIX_NO_DF
|
||||
prefix = prefix if prefix is not None else PREFIX
|
||||
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix_to_use,
|
||||
)
|
||||
template = "\n\n".join([prefix, "{tools}", FORMAT_INSTRUCTIONS, suffix_to_use])
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
|
||||
partial_prompt = prompt.partial()
|
||||
if "df_head" in partial_prompt.input_variables:
|
||||
@@ -257,7 +253,6 @@ def create_pandas_dataframe_agent(
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
tools=tools,
|
||||
)
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent] = RunnableAgent(
|
||||
runnable=create_react_agent(llm, tools, prompt), # type: ignore
|
||||
|
Reference in New Issue
Block a user