mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
update agents to use tool call messages (#20074)
```python from langchain.agents import AgentExecutor, create_tool_calling_agent, tool from langchain_anthropic import ChatAnthropic from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder prompt = ChatPromptTemplate.from_messages( [ ("system", "You are a helpful assistant"), MessagesPlaceholder("chat_history", optional=True), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad"), ] ) model = ChatAnthropic(model="claude-3-opus-20240229") @tool def magic_function(input: int) -> int: """Applies a magic function to an input.""" return input + 2 tools = [magic_function] agent = create_tool_calling_agent(model, tools, prompt) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) agent_executor.invoke({"input": "what is the value of magic_function(3)?"}) ``` ``` > Entering new AgentExecutor chain... Invoking: `magic_function` with `{'input': 3}` responded: [{'text': '<thinking>\nThe user has asked for the value of magic_function applied to the input 3. Looking at the available tools, magic_function is the relevant one to use here, as it takes an integer input and returns an integer output.\n\nThe magic_function has one required parameter:\n- input (integer)\n\nThe user has directly provided the value 3 for the input parameter. Since the required parameter is present, we can proceed with calling the function.\n</thinking>', 'type': 'text'}, {'id': 'toolu_01HsTheJPA5mcipuFDBbJ1CW', 'input': {'input': 3}, 'name': 'magic_function', 'type': 'tool_use'}] 5 Therefore, the value of magic_function(3) is 5. > Finished chain. {'input': 'what is the value of magic_function(3)?', 'output': 'Therefore, the value of magic_function(3) is 5.'} ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
9eb6f538f0
commit
21c1ce0bc1
@ -126,12 +126,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
"agents",
|
||||
"AgentActionMessageLog",
|
||||
),
|
||||
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
|
||||
("langchain", "schema", "agent", "ToolAgentAction"): (
|
||||
"langchain",
|
||||
"agents",
|
||||
"output_parsers",
|
||||
"openai_tools",
|
||||
"OpenAIToolAgentAction",
|
||||
"tools",
|
||||
"ToolAgentAction",
|
||||
),
|
||||
("langchain", "prompts", "chat", "BaseMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
@ -528,6 +528,13 @@ _OG_SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
"image",
|
||||
"ImagePromptTemplate",
|
||||
),
|
||||
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
|
||||
"langchain",
|
||||
"agents",
|
||||
"output_parsers",
|
||||
"openai_tools",
|
||||
"OpenAIToolAgentAction",
|
||||
),
|
||||
}
|
||||
|
||||
# Needed for backwards compatibility for a few versions where we serialized
|
||||
|
@ -82,6 +82,7 @@ from langchain.agents.structured_chat.base import (
|
||||
StructuredChatAgent,
|
||||
create_structured_chat_agent,
|
||||
)
|
||||
from langchain.agents.tool_calling_agent.base import create_tool_calling_agent
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.agents.xml.base import XMLAgent, create_xml_agent
|
||||
|
||||
@ -154,4 +155,5 @@ __all__ = [
|
||||
"create_self_ask_with_search_agent",
|
||||
"create_json_chat_agent",
|
||||
"create_structured_chat_agent",
|
||||
"create_tool_calling_agent",
|
||||
]
|
||||
|
@ -11,12 +11,14 @@ from langchain.agents.format_scratchpad.openai_functions import (
|
||||
format_to_openai_function_messages,
|
||||
format_to_openai_functions,
|
||||
)
|
||||
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
|
||||
from langchain.agents.format_scratchpad.xml import format_xml
|
||||
|
||||
__all__ = [
|
||||
"format_xml",
|
||||
"format_to_openai_function_messages",
|
||||
"format_to_openai_functions",
|
||||
"format_to_tool_messages",
|
||||
"format_log_to_str",
|
||||
"format_log_to_messages",
|
||||
]
|
||||
|
@ -1,59 +1,5 @@
|
||||
import json
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
from langchain.agents.format_scratchpad.tools import (
|
||||
format_to_tool_messages as format_to_openai_tool_messages,
|
||||
)
|
||||
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
|
||||
|
||||
|
||||
def _create_tool_message(
|
||||
agent_action: OpenAIToolAgentAction, observation: str
|
||||
) -> ToolMessage:
|
||||
"""Convert agent action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tool invocation request from the agent
|
||||
observation: the result of the tool invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tool invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return ToolMessage(
|
||||
tool_call_id=agent_action.tool_call_id,
|
||||
content=content,
|
||||
additional_kwargs={"name": agent_action.tool},
|
||||
)
|
||||
|
||||
|
||||
def format_to_openai_tool_messages(
|
||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
|
||||
Returns:
|
||||
list of messages to send to the LLM for the next prediction
|
||||
|
||||
"""
|
||||
messages = []
|
||||
for agent_action, observation in intermediate_steps:
|
||||
if isinstance(agent_action, OpenAIToolAgentAction):
|
||||
new_messages = list(agent_action.message_log) + [
|
||||
_create_tool_message(agent_action, observation)
|
||||
]
|
||||
messages.extend([new for new in new_messages if new not in messages])
|
||||
else:
|
||||
messages.append(AIMessage(content=agent_action.log))
|
||||
return messages
|
||||
__all__ = ["format_to_openai_tool_messages"]
|
||||
|
59
libs/langchain/langchain/agents/format_scratchpad/tools.py
Normal file
59
libs/langchain/langchain/agents/format_scratchpad/tools.py
Normal file
@ -0,0 +1,59 @@
|
||||
import json
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from langchain_core.agents import AgentAction
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
|
||||
from langchain.agents.output_parsers.tools import ToolAgentAction
|
||||
|
||||
|
||||
def _create_tool_message(
|
||||
agent_action: ToolAgentAction, observation: str
|
||||
) -> ToolMessage:
|
||||
"""Convert agent action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tool invocation request from the agent
|
||||
observation: the result of the tool invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tool invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return ToolMessage(
|
||||
tool_call_id=agent_action.tool_call_id,
|
||||
content=content,
|
||||
additional_kwargs={"name": agent_action.tool},
|
||||
)
|
||||
|
||||
|
||||
def format_to_tool_messages(
|
||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
|
||||
Returns:
|
||||
list of messages to send to the LLM for the next prediction
|
||||
|
||||
"""
|
||||
messages = []
|
||||
for agent_action, observation in intermediate_steps:
|
||||
if isinstance(agent_action, ToolAgentAction):
|
||||
new_messages = list(agent_action.message_log) + [
|
||||
_create_tool_message(agent_action, observation)
|
||||
]
|
||||
messages.extend([new for new in new_messages if new not in messages])
|
||||
else:
|
||||
messages.append(AIMessage(content=agent_action.log))
|
||||
return messages
|
@ -20,11 +20,13 @@ from langchain.agents.output_parsers.react_single_input import (
|
||||
ReActSingleInputOutputParser,
|
||||
)
|
||||
from langchain.agents.output_parsers.self_ask import SelfAskOutputParser
|
||||
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
|
||||
from langchain.agents.output_parsers.xml import XMLAgentOutputParser
|
||||
|
||||
__all__ = [
|
||||
"ReActSingleInputOutputParser",
|
||||
"SelfAskOutputParser",
|
||||
"ToolsAgentOutputParser",
|
||||
"ReActJsonSingleInputOutputParser",
|
||||
"OpenAIFunctionsAgentOutputParser",
|
||||
"XMLAgentOutputParser",
|
||||
|
@ -1,70 +1,40 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
|
||||
from langchain.agents.agent import MultiActionAgentOutputParser
|
||||
from langchain.agents.output_parsers.tools import (
|
||||
ToolAgentAction,
|
||||
parse_ai_message_to_tool_action,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIToolAgentAction(AgentActionMessageLog):
|
||||
tool_call_id: str
|
||||
"""Tool call that this message is responding to."""
|
||||
OpenAIToolAgentAction = ToolAgentAction
|
||||
|
||||
|
||||
def parse_ai_message_to_openai_tool_action(
|
||||
message: BaseMessage,
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Parse an AI message potentially containing tool_calls."""
|
||||
if not isinstance(message, AIMessage):
|
||||
raise TypeError(f"Expected an AI message got {type(message)}")
|
||||
|
||||
if not message.additional_kwargs.get("tool_calls"):
|
||||
return AgentFinish(
|
||||
return_values={"output": message.content}, log=str(message.content)
|
||||
)
|
||||
|
||||
actions: List = []
|
||||
for tool_call in message.additional_kwargs["tool_calls"]:
|
||||
function = tool_call["function"]
|
||||
function_name = function["name"]
|
||||
try:
|
||||
_tool_input = json.loads(function["arguments"] or "{}")
|
||||
except JSONDecodeError:
|
||||
raise OutputParserException(
|
||||
f"Could not parse tool input: {function} because "
|
||||
f"the `arguments` is not valid JSON."
|
||||
)
|
||||
|
||||
# HACK HACK HACK:
|
||||
# The code that encodes tool input into Open AI uses a special variable
|
||||
# name called `__arg1` to handle old style tools that do not expose a
|
||||
# schema and expect a single string argument as an input.
|
||||
# We unpack the argument here if it exists.
|
||||
# Open AI does not support passing in a JSON array as an argument.
|
||||
if "__arg1" in _tool_input:
|
||||
tool_input = _tool_input["__arg1"]
|
||||
else:
|
||||
tool_input = _tool_input
|
||||
|
||||
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
||||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
||||
actions.append(
|
||||
tool_actions = parse_ai_message_to_tool_action(message)
|
||||
if isinstance(tool_actions, AgentFinish):
|
||||
return tool_actions
|
||||
final_actions: List[AgentAction] = []
|
||||
for action in tool_actions:
|
||||
if isinstance(action, ToolAgentAction):
|
||||
final_actions.append(
|
||||
OpenAIToolAgentAction(
|
||||
tool=function_name,
|
||||
tool_input=tool_input,
|
||||
log=log,
|
||||
message_log=[message],
|
||||
tool_call_id=tool_call["id"],
|
||||
tool=action.tool,
|
||||
tool_input=action.tool_input,
|
||||
log=action.log,
|
||||
message_log=action.message_log,
|
||||
tool_call_id=action.tool_call_id,
|
||||
)
|
||||
)
|
||||
return actions
|
||||
else:
|
||||
final_actions.append(action)
|
||||
return final_actions
|
||||
|
||||
|
||||
class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
|
||||
|
102
libs/langchain/langchain/agents/output_parsers/tools.py
Normal file
102
libs/langchain/langchain/agents/output_parsers/tools.py
Normal file
@ -0,0 +1,102 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Union
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
|
||||
from langchain.agents.agent import MultiActionAgentOutputParser
|
||||
|
||||
|
||||
class ToolAgentAction(AgentActionMessageLog):
|
||||
tool_call_id: str
|
||||
"""Tool call that this message is responding to."""
|
||||
|
||||
|
||||
def parse_ai_message_to_tool_action(
|
||||
message: BaseMessage,
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Parse an AI message potentially containing tool_calls."""
|
||||
if not isinstance(message, AIMessage):
|
||||
raise TypeError(f"Expected an AI message got {type(message)}")
|
||||
|
||||
actions: List = []
|
||||
if message.tool_calls:
|
||||
tool_calls = message.tool_calls
|
||||
else:
|
||||
if not message.additional_kwargs.get("tool_calls"):
|
||||
return AgentFinish(
|
||||
return_values={"output": message.content}, log=str(message.content)
|
||||
)
|
||||
# Best-effort parsing
|
||||
tool_calls = []
|
||||
for tool_call in message.additional_kwargs["tool_calls"]:
|
||||
function = tool_call["function"]
|
||||
function_name = function["name"]
|
||||
try:
|
||||
args = json.loads(function["arguments"] or "{}")
|
||||
tool_calls.append(
|
||||
ToolCall(name=function_name, args=args, id=tool_call["id"])
|
||||
)
|
||||
except JSONDecodeError:
|
||||
raise OutputParserException(
|
||||
f"Could not parse tool input: {function} because "
|
||||
f"the `arguments` is not valid JSON."
|
||||
)
|
||||
for tool_call in tool_calls:
|
||||
# HACK HACK HACK:
|
||||
# The code that encodes tool input into Open AI uses a special variable
|
||||
# name called `__arg1` to handle old style tools that do not expose a
|
||||
# schema and expect a single string argument as an input.
|
||||
# We unpack the argument here if it exists.
|
||||
# Open AI does not support passing in a JSON array as an argument.
|
||||
function_name = tool_call["name"]
|
||||
_tool_input = tool_call["args"]
|
||||
if "__arg1" in _tool_input:
|
||||
tool_input = _tool_input["__arg1"]
|
||||
else:
|
||||
tool_input = _tool_input
|
||||
|
||||
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
||||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
||||
actions.append(
|
||||
ToolAgentAction(
|
||||
tool=function_name,
|
||||
tool_input=tool_input,
|
||||
log=log,
|
||||
message_log=[message],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
return actions
|
||||
|
||||
|
||||
class ToolsAgentOutputParser(MultiActionAgentOutputParser):
|
||||
"""Parses a message into agent actions/finish.
|
||||
|
||||
If a tool_calls parameter is passed, then that is used to get
|
||||
the tool names and tool inputs.
|
||||
|
||||
If one is not passed, then the AIMessage is assumed to be the final output.
|
||||
"""
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "tools-agent-output-parser"
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> Union[List[AgentAction], AgentFinish]:
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
raise ValueError("This output parser only works on ChatGeneration output")
|
||||
message = result[0].message
|
||||
return parse_ai_message_to_tool_action(message)
|
||||
|
||||
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
|
||||
raise ValueError("Can only parse messages")
|
96
libs/langchain/langchain/agents/tool_calling_agent/base.py
Normal file
96
libs/langchain/langchain/agents/tool_calling_agent/base.py
Normal file
@ -0,0 +1,96 @@
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain.agents.format_scratchpad.tools import (
|
||||
format_to_tool_messages,
|
||||
)
|
||||
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
|
||||
|
||||
|
||||
def create_tool_calling_agent(
|
||||
llm: BaseLanguageModel, tools: Sequence[BaseTool], prompt: ChatPromptTemplate
|
||||
) -> Runnable:
|
||||
"""Create an agent that uses tools.
|
||||
|
||||
Args:
|
||||
llm: LLM to use as the agent.
|
||||
tools: Tools this agent has access to.
|
||||
prompt: The prompt to use. See Prompt section below for more on the expected
|
||||
input variables.
|
||||
|
||||
Returns:
|
||||
A Runnable sequence representing an agent. It takes as input all the same input
|
||||
variables as the prompt passed in does. It returns as output either an
|
||||
AgentAction or AgentFinish.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a helpful assistant"),
|
||||
MessagesPlaceholder("chat_history", optional=True),
|
||||
("human", "{input}"),
|
||||
MessagesPlaceholder("agent_scratchpad"),
|
||||
]
|
||||
)
|
||||
model = ChatAnthropic(model="claude-3-opus-20240229")
|
||||
|
||||
@tool
|
||||
def magic_function(input: int) -> int:
|
||||
\"\"\"Applies a magic function to an input.\"\"\"
|
||||
return input + 2
|
||||
|
||||
tools = [magic_function]
|
||||
|
||||
agent = create_tool_calling_agent(model, tools, prompt)
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
agent_executor.invoke({"input": "what is the value of magic_function(3)?"})
|
||||
|
||||
# Using with chat history
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
agent_executor.invoke(
|
||||
{
|
||||
"input": "what's my name?",
|
||||
"chat_history": [
|
||||
HumanMessage(content="hi! my name is bob"),
|
||||
AIMessage(content="Hello Bob! How can I assist you today?"),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
Prompt:
|
||||
|
||||
The agent prompt must have an `agent_scratchpad` key that is a
|
||||
``MessagesPlaceholder``. Intermediate agent actions and tool output
|
||||
messages will be passed in here.
|
||||
"""
|
||||
missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables)
|
||||
if missing_vars:
|
||||
raise ValueError(f"Prompt missing required variables: {missing_vars}")
|
||||
|
||||
if not hasattr(llm, "bind_tools"):
|
||||
raise ValueError(
|
||||
"This function requires a .bind_tools method be implemented on the LLM.",
|
||||
)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
|
||||
)
|
||||
| prompt
|
||||
| llm_with_tools
|
||||
| ToolsAgentOutputParser()
|
||||
)
|
||||
return agent
|
@ -1,4 +1,4 @@
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
|
||||
from langchain.agents.format_scratchpad.openai_tools import (
|
||||
format_to_openai_tool_messages,
|
||||
@ -49,16 +49,27 @@ def test_calls_convert_agent_action_to_messages() -> None:
|
||||
}
|
||||
message3 = AIMessage(content="", additional_kwargs=additional_kwargs3)
|
||||
actions3 = parse_ai_message_to_openai_tool_action(message3)
|
||||
|
||||
message4 = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(name="exponentiate", args={"a": 3, "b": 5}, id="call_abc02468")
|
||||
],
|
||||
)
|
||||
actions4 = parse_ai_message_to_openai_tool_action(message4)
|
||||
|
||||
# for mypy
|
||||
assert isinstance(actions1, list)
|
||||
assert isinstance(actions2, list)
|
||||
assert isinstance(actions3, list)
|
||||
assert isinstance(actions4, list)
|
||||
|
||||
intermediate_steps = [
|
||||
(actions1[0], "observation1"),
|
||||
(actions2[0], "observation2"),
|
||||
(actions3[0], "observation3"),
|
||||
(actions3[1], "observation4"),
|
||||
(actions4[0], "observation4"),
|
||||
]
|
||||
expected_messages = [
|
||||
message1,
|
||||
@ -84,6 +95,12 @@ def test_calls_convert_agent_action_to_messages() -> None:
|
||||
content="observation4",
|
||||
additional_kwargs={"name": "divide"},
|
||||
),
|
||||
message4,
|
||||
ToolMessage(
|
||||
tool_call_id="call_abc02468",
|
||||
content="observation4",
|
||||
additional_kwargs={"name": "exponentiate"},
|
||||
),
|
||||
]
|
||||
output = format_to_openai_tool_messages(intermediate_steps)
|
||||
assert output == expected_messages
|
||||
|
@ -16,6 +16,7 @@ from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
ToolCall,
|
||||
)
|
||||
from langchain_core.prompts import MessagesPlaceholder
|
||||
from langchain_core.runnables.utils import add
|
||||
@ -27,6 +28,7 @@ from langchain.agents import (
|
||||
AgentType,
|
||||
create_openai_functions_agent,
|
||||
create_openai_tools_agent,
|
||||
create_tool_calling_agent,
|
||||
initialize_agent,
|
||||
)
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
|
||||
@ -940,16 +942,20 @@ def _make_tools_invocation(name_to_arguments: Dict[str, Dict[str, Any]]) -> AIMe
|
||||
Returns:
|
||||
AIMessage that represents a request to invoke a tool.
|
||||
"""
|
||||
tool_calls = [
|
||||
raw_tool_calls = [
|
||||
{"function": {"name": name, "arguments": json.dumps(arguments)}, "id": idx}
|
||||
for idx, (name, arguments) in enumerate(name_to_arguments.items())
|
||||
]
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(name=name, args=args, id=str(idx))
|
||||
for idx, (name, args) in enumerate(name_to_arguments.items())
|
||||
]
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"tool_calls": tool_calls,
|
||||
"tool_calls": raw_tool_calls,
|
||||
},
|
||||
tool_calls=tool_calls, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
@ -967,6 +973,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
]
|
||||
)
|
||||
|
||||
GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
@tool
|
||||
@ -993,11 +1000,17 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
|
||||
# type error due to base tool type below -- would need to be adjusted on tool
|
||||
# decorator.
|
||||
agent = create_openai_tools_agent(
|
||||
openai_agent = create_openai_tools_agent(
|
||||
model,
|
||||
[find_pet], # type: ignore[list-item]
|
||||
template,
|
||||
)
|
||||
tool_calling_agent = create_tool_calling_agent(
|
||||
model,
|
||||
[find_pet], # type: ignore[list-item]
|
||||
template,
|
||||
)
|
||||
for agent in [openai_agent, tool_calling_agent]:
|
||||
executor = AgentExecutor(agent=agent, tools=[find_pet]) # type: ignore[arg-type, list-item]
|
||||
|
||||
# Invoke
|
||||
@ -1009,7 +1022,9 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
|
||||
# astream
|
||||
chunks = [chunk async for chunk in executor.astream({"question": "hello"})]
|
||||
assert chunks == [
|
||||
assert (
|
||||
chunks
|
||||
== [
|
||||
{
|
||||
"actions": [
|
||||
OpenAIToolAgentAction(
|
||||
@ -1057,7 +1072,10 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {"name": "check_time", "arguments": "{}"},
|
||||
"function": {
|
||||
"name": "check_time",
|
||||
"arguments": "{}",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
@ -1112,7 +1130,10 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
"id": 0,
|
||||
},
|
||||
{
|
||||
"function": {"name": "check_time", "arguments": "{}"},
|
||||
"function": {
|
||||
"name": "check_time",
|
||||
"arguments": "{}",
|
||||
},
|
||||
"id": 1,
|
||||
},
|
||||
]
|
||||
@ -1122,14 +1143,16 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
FunctionMessage(content="Spying from under the bed.", name="find_pet")
|
||||
FunctionMessage(
|
||||
content="Spying from under the bed.", name="find_pet"
|
||||
)
|
||||
],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=OpenAIToolAgentAction(
|
||||
tool="find_pet",
|
||||
tool_input={"pet": "cat"},
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n", # noqa: E501
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
id=AnyStr(),
|
||||
@ -1163,7 +1186,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
{
|
||||
"messages": [
|
||||
FunctionMessage(
|
||||
content="check_time is not a valid tool, try one of [find_pet].",
|
||||
content="check_time is not a valid tool, try one of [find_pet].", # noqa: E501
|
||||
name="check_time",
|
||||
)
|
||||
],
|
||||
@ -1205,10 +1228,13 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [AIMessage(content="The cat is spying from under the bed.")],
|
||||
"messages": [
|
||||
AIMessage(content="The cat is spying from under the bed.")
|
||||
],
|
||||
"output": "The cat is spying from under the bed.",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
# astream_log
|
||||
log_patches = [
|
||||
|
@ -43,6 +43,7 @@ EXPECTED_ALL = [
|
||||
"create_self_ask_with_search_agent",
|
||||
"create_json_chat_agent",
|
||||
"create_structured_chat_agent",
|
||||
"create_tool_calling_agent",
|
||||
]
|
||||
|
||||
|
||||
|
@ -42,6 +42,7 @@ _EXPECTED = [
|
||||
"create_self_ask_with_search_agent",
|
||||
"create_json_chat_agent",
|
||||
"create_structured_chat_agent",
|
||||
"create_tool_calling_agent",
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user