Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
3bb461ad21 langchain[patch]: support binding oustide create_react_agent 2024-01-29 15:37:47 -08:00

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
from typing import Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models import LanguageModelLike
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableBindingBase
from langchain_core.tools import BaseTool
from langchain.agents.format_scratchpad import format_log_to_str
@@ -13,7 +14,7 @@ from langchain.tools.render import render_text_description
def create_react_agent(
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: BasePromptTemplate
llm: LanguageModelLike, tools: Sequence[BaseTool], prompt: BasePromptTemplate
) -> Runnable:
"""Create an agent that uses ReAct prompting.
@@ -32,11 +33,12 @@ def create_react_agent(
.. code-block:: python
from langchain import hub
from langchain_community.llms import OpenAI
from langchain_core.messages import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_react_agent
prompt = hub.pull("hwchase17/react")
model = OpenAI()
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
tools = ...
agent = create_react_agent(model, tools, prompt)
@@ -45,7 +47,6 @@ def create_react_agent(
agent_executor.invoke({"input": "hi"})
# Use with chat history
from langchain_core.messages import AIMessage, HumanMessage
agent_executor.invoke(
{
"input": "what's my name?",
@@ -55,6 +56,11 @@ def create_react_agent(
}
)
# Binding additional stop words to llm
model_with_stop = model.bind(stop=["Question"])
agent = create_react_agent(model_with_stop, tools, prompt)
...
Prompt:
The prompt must have input keys:
@@ -100,7 +106,15 @@ def create_react_agent(
tools=render_text_description(list(tools)),
tool_names=", ".join([t.name for t in tools]),
)
llm_with_stop = llm.bind(stop=["\nObservation"])
if (
isinstance(llm, RunnableBindingBase)
and (stop := llm.kwargs.get("stop"))
and "\nObservation" not in stop
):
stop = stop + ["\nObservation"]
else:
stop = ["\nObservation"]
llm_with_stop = llm.bind(stop=stop)
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),