diff --git a/libs/langchain/langchain/agents/react/agent.py b/libs/langchain/langchain/agents/react/agent.py index 93709fca65b..f1aa769dc40 100644 --- a/libs/langchain/langchain/agents/react/agent.py +++ b/libs/langchain/langchain/agents/react/agent.py @@ -1,19 +1,23 @@ from __future__ import annotations -from typing import Sequence +from typing import Optional, Sequence from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool +from langchain.agents import AgentOutputParser from langchain.agents.format_scratchpad import format_log_to_str from langchain.agents.output_parsers import ReActSingleInputOutputParser from langchain.tools.render import render_text_description def create_react_agent( - llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: BasePromptTemplate + llm: BaseLanguageModel, + tools: Sequence[BaseTool], + prompt: BasePromptTemplate, + output_parser: Optional[AgentOutputParser] = None, ) -> Runnable: """Create an agent that uses ReAct prompting. @@ -21,6 +25,7 @@ def create_react_agent( llm: LLM to use as the agent. tools: Tools this agent has access to. prompt: The prompt to use. See Prompt section below for more. + output_parser: AgentOutputParser for parse the LLM output. Returns: A Runnable sequence representing an agent. It takes as input all the same input @@ -101,12 +106,13 @@ def create_react_agent( tool_names=", ".join([t.name for t in tools]), ) llm_with_stop = llm.bind(stop=["\nObservation"]) + output_parser = output_parser or ReActSingleInputOutputParser() agent = ( RunnablePassthrough.assign( agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]), ) | prompt | llm_with_stop - | ReActSingleInputOutputParser() + | output_parser ) return agent