mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
Fix OAI Tool Message (#14746)
See format here: https://platform.openai.com/docs/guides/function-calling/parallel-function-calling It expects a "name" argument, which we aren't providing by default.  Alternative is to add the 'name' field directly to the message if people prefer.
This commit is contained in:
@@ -89,7 +89,14 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
|||||||
elif role == "function":
|
elif role == "function":
|
||||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
|
additional_kwargs = {}
|
||||||
|
if "name" in _dict:
|
||||||
|
additional_kwargs["name"] = _dict["name"]
|
||||||
|
return ToolMessage(
|
||||||
|
content=_dict["content"],
|
||||||
|
tool_call_id=_dict["tool_call_id"],
|
||||||
|
additional_kwargs=additional_kwargs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return ChatMessage(content=_dict["content"], role=role)
|
return ChatMessage(content=_dict["content"], role=role)
|
||||||
|
|
||||||
|
@@ -31,6 +31,7 @@ def _create_tool_message(
|
|||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
tool_call_id=agent_action.tool_call_id,
|
tool_call_id=agent_action.tool_call_id,
|
||||||
content=content,
|
content=content,
|
||||||
|
additional_kwargs={"name": agent_action.tool},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -0,0 +1,94 @@
|
|||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from langchain.agents.format_scratchpad.openai_tools import (
|
||||||
|
format_to_openai_tool_messages,
|
||||||
|
)
|
||||||
|
from langchain.agents.output_parsers.openai_tools import (
|
||||||
|
parse_ai_message_to_openai_tool_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calls_convert_agent_action_to_messages() -> None:
|
||||||
|
additional_kwargs1 = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_abcd12345",
|
||||||
|
"function": {"arguments": '{"a": 3, "b": 5}', "name": "add"},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
message1 = AIMessage(content="", additional_kwargs=additional_kwargs1)
|
||||||
|
|
||||||
|
actions1 = parse_ai_message_to_openai_tool_action(message1)
|
||||||
|
additional_kwargs2 = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_abcd54321",
|
||||||
|
"function": {"arguments": '{"a": 3, "b": 5}', "name": "subtract"},
|
||||||
|
"type": "function",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
message2 = AIMessage(content="", additional_kwargs=additional_kwargs2)
|
||||||
|
actions2 = parse_ai_message_to_openai_tool_action(message2)
|
||||||
|
|
||||||
|
additional_kwargs3 = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_abcd67890",
|
||||||
|
"function": {"arguments": '{"a": 3, "b": 5}', "name": "multiply"},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_abcd09876",
|
||||||
|
"function": {"arguments": '{"a": 3, "b": 5}', "name": "divide"},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
message3 = AIMessage(content="", additional_kwargs=additional_kwargs3)
|
||||||
|
actions3 = parse_ai_message_to_openai_tool_action(message3)
|
||||||
|
# for mypy
|
||||||
|
assert isinstance(actions1, list)
|
||||||
|
assert isinstance(actions2, list)
|
||||||
|
assert isinstance(actions3, list)
|
||||||
|
|
||||||
|
intermediate_steps = [
|
||||||
|
(actions1[0], "observation1"),
|
||||||
|
(actions2[0], "observation2"),
|
||||||
|
(actions3[0], "observation3"),
|
||||||
|
(actions3[1], "observation4"),
|
||||||
|
]
|
||||||
|
expected_messages = [
|
||||||
|
message1,
|
||||||
|
ToolMessage(
|
||||||
|
tool_call_id="call_abcd12345",
|
||||||
|
content="observation1",
|
||||||
|
additional_kwargs={"name": "add"},
|
||||||
|
),
|
||||||
|
message2,
|
||||||
|
ToolMessage(
|
||||||
|
tool_call_id="call_abcd54321",
|
||||||
|
content="observation2",
|
||||||
|
additional_kwargs={"name": "subtract"},
|
||||||
|
),
|
||||||
|
message3,
|
||||||
|
ToolMessage(
|
||||||
|
tool_call_id="call_abcd67890",
|
||||||
|
content="observation3",
|
||||||
|
additional_kwargs={"name": "multiply"},
|
||||||
|
),
|
||||||
|
ToolMessage(
|
||||||
|
tool_call_id="call_abcd09876",
|
||||||
|
content="observation4",
|
||||||
|
additional_kwargs={"name": "divide"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
output = format_to_openai_tool_messages(intermediate_steps)
|
||||||
|
assert output == expected_messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_empty_input_list() -> None:
|
||||||
|
output = format_to_openai_tool_messages([])
|
||||||
|
assert output == []
|
Reference in New Issue
Block a user