mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
Pass kwargs from initialize_agent into agent classmethod (#799)
# Problem I noticed that in order to change the prefix of the prompt in the `zero-shot-react-description` agent we had to dig around to subset strings deep into the agent's attributes. It requires the user to inspect a long chain of attributes and classes. `initialize_agent -> AgentExecutor -> Agent -> LLMChain -> Prompt from Agent.create_prompt` ``` python agent = initialize_agent( tools=tools, llm=fake_llm, agent="zero-shot-react-description" ) prompt_str = agent.agent.llm_chain.prompt.template new_prompt_str = change_prefix(prompt_str) agent.agent.llm_chain.prompt.template = new_prompt_str ``` # Implemented Solution `initialize_agent` accepts `**kwargs` but passes it to `AgentExecutor` but not `ZeroShotAgent`, by simply giving the kwargs to the agent class methods we can support changing the prefix and suffix for one agent while allowing future agents to take advantage of `initialize_agent`. ``` agent = initialize_agent( tools=tools, llm=fake_llm, agent="zero-shot-react-description", agent_kwargs={"prefix": prefix, "suffix": suffix} ) ``` To be fair, this was before finding docs around custom agents here: https://langchain.readthedocs.io/en/latest/modules/agents/examples/custom_agent.html?highlight=custom%20#custom-llmchain but i find that my use case just needed to change the prefix a little. # Changes * Pass kwargs to Agent class method * Added a test to check suffix and prefix --------- Co-authored-by: Jason Liu <jason@jxnl.coA>
This commit is contained in:
parent
c331009440
commit
54f9e4287f
@ -14,6 +14,7 @@ def initialize_agent(
|
|||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
agent_path: Optional[str] = None,
|
agent_path: Optional[str] = None,
|
||||||
|
agent_kwargs: Optional[dict] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AgentExecutor:
|
) -> AgentExecutor:
|
||||||
"""Load agent given tools and LLM.
|
"""Load agent given tools and LLM.
|
||||||
@ -50,8 +51,9 @@ def initialize_agent(
|
|||||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||||
)
|
)
|
||||||
agent_cls = AGENT_TO_CLASS[agent]
|
agent_cls = AGENT_TO_CLASS[agent]
|
||||||
|
agent_kwargs = agent_kwargs or {}
|
||||||
agent_obj = agent_cls.from_llm_and_tools(
|
agent_obj = agent_cls.from_llm_and_tools(
|
||||||
llm, tools, callback_manager=callback_manager
|
llm, tools, callback_manager=callback_manager, **agent_kwargs
|
||||||
)
|
)
|
||||||
elif agent_path is not None:
|
elif agent_path is not None:
|
||||||
agent_obj = load_agent(
|
agent_obj = load_agent(
|
||||||
|
@ -196,3 +196,29 @@ def test_agent_tool_return_direct() -> None:
|
|||||||
|
|
||||||
output = agent.run("when was langchain made")
|
output = agent.run("when was langchain made")
|
||||||
assert output == "misalignment"
|
assert output == "misalignment"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_new_prefix_suffix() -> None:
|
||||||
|
"""Test agent initilization kwargs with new prefix and suffix."""
|
||||||
|
fake_llm = FakeListLLM(
|
||||||
|
responses=["FooBarBaz\nAction: Search\nAction Input: misalignment"]
|
||||||
|
)
|
||||||
|
tools = [
|
||||||
|
Tool("Search", lambda x: x, "Useful for searching", return_direct=True),
|
||||||
|
]
|
||||||
|
prefix = "FooBarBaz"
|
||||||
|
|
||||||
|
suffix = "Begin now!\nInput: {input}\nThought: {agent_scratchpad}"
|
||||||
|
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools=tools,
|
||||||
|
llm=fake_llm,
|
||||||
|
agent="zero-shot-react-description",
|
||||||
|
agent_kwargs={"prefix": prefix, "suffix": suffix},
|
||||||
|
)
|
||||||
|
|
||||||
|
# avoids "BasePromptTemplate" has no attribute "template" error
|
||||||
|
assert hasattr(agent.agent.llm_chain.prompt, "template")
|
||||||
|
prompt_str = agent.agent.llm_chain.prompt.template
|
||||||
|
assert prompt_str.startswith(prefix), "Prompt does not start with prefix"
|
||||||
|
assert prompt_str.endswith(suffix), "Prompt does not end with suffix"
|
||||||
|
Loading…
Reference in New Issue
Block a user