core[patch]: add tool name to tool message (#24243)

Copying current ToolNode behavior
This commit is contained in:
Bagatur 2024-07-14 17:42:40 -07:00 committed by GitHub
parent 9224027e45
commit d0728b0ba0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 6 deletions

View File

@ -580,7 +580,7 @@ class ChildTool(BaseTool):
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id)
output = _format_output(content, artifact, tool_call_id, self.name)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@ -672,7 +672,7 @@ class ChildTool(BaseTool):
await run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id)
output = _format_output(content, artifact, tool_call_id, self.name)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@ -1385,7 +1385,7 @@ def _prep_run_args(
def _format_output(
content: Any, artifact: Any, tool_call_id: Optional[str]
content: Any, artifact: Any, tool_call_id: Optional[str], name: str
) -> Union[ToolMessage, Any]:
if tool_call_id:
# NOTE: This will fail to stringify lists which aren't actually content blocks
@ -1397,7 +1397,9 @@ def _format_output(
and isinstance(content[0], (str, dict))
):
content = _stringify(content)
return ToolMessage(content, artifact=artifact, tool_call_id=tool_call_id)
return ToolMessage(
content, artifact=artifact, tool_call_id=tool_call_id, name=name
)
else:
return content

View File

@ -1137,7 +1137,9 @@ def test_tool_call_input_tool_message_output() -> None:
"type": "tool_call",
}
tool = _MockStructuredTool()
expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
expected = ToolMessage(
"1 True {'img': 'base64string...'}", tool_call_id="123", name="structured_api"
)
actual = tool.invoke(tool_call)
assert actual == expected
@ -1176,7 +1178,9 @@ def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:
"id": "123",
"type": "tool_call",
}
expected = ToolMessage("1 True", artifact=tool_call["args"], tool_call_id="123")
expected = ToolMessage(
"1 True", artifact=tool_call["args"], tool_call_id="123", name="structured_api"
)
actual = tool.invoke(tool_call)
assert actual == expected