mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
core[patch]: introduce ToolMessage.status (#24628)
Anthropic models (including via Bedrock and other cloud platforms) accept a status/is_error attribute on tool messages/results (specifically in `tool_result` content blocks for Anthropic API). Adding a ToolMessage.status attribute so that users can set this attribute when using those models
This commit is contained in:
parent
78d97b49d9
commit
a6d1fb4275
@ -50,9 +50,6 @@ class ToolMessage(BaseMessage):
|
||||
|
||||
tool_call_id: str
|
||||
"""Tool call that this message is responding to."""
|
||||
# TODO: Add is_error param?
|
||||
# is_error: bool = False
|
||||
# """Whether the tool errored."""
|
||||
|
||||
type: Literal["tool"] = "tool"
|
||||
"""The type of the message (used for serialization). Defaults to "tool"."""
|
||||
@ -67,6 +64,12 @@ class ToolMessage(BaseMessage):
|
||||
.. versionadded:: 0.2.17
|
||||
"""
|
||||
|
||||
status: Literal["success", "error"] = "success"
|
||||
"""Status of the tool invocation.
|
||||
|
||||
.. versionadded:: 0.2.24
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
@ -119,6 +122,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
self.response_metadata, other.response_metadata
|
||||
),
|
||||
id=self.id,
|
||||
status=_merge_status(self.status, other.status),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
@ -280,3 +284,9 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
|
||||
)
|
||||
tool_call_chunks.append(parsed)
|
||||
return tool_call_chunks
|
||||
|
||||
|
||||
def _merge_status(
|
||||
left: Literal["success", "error"], right: Literal["success", "error"]
|
||||
) -> Literal["success", "error"]:
|
||||
return "error" if "error" in (left, right) else "success"
|
||||
|
@ -625,23 +625,27 @@ class ChildTool(BaseTool):
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
output = _format_output(content, artifact, tool_call_id, self.name)
|
||||
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@ -737,24 +741,28 @@ class ChildTool(BaseTool):
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
|
||||
if error_to_raise:
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
|
||||
output = _format_output(content, artifact, tool_call_id, self.name)
|
||||
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@ -1511,7 +1519,7 @@ def _prep_run_args(
|
||||
|
||||
|
||||
def _format_output(
|
||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str
|
||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
|
||||
) -> Union[ToolMessage, Any]:
|
||||
if tool_call_id:
|
||||
# NOTE: This will fail to stringify lists which aren't actually content blocks
|
||||
@ -1524,7 +1532,11 @@ def _format_output(
|
||||
):
|
||||
content = _stringify(content)
|
||||
return ToolMessage(
|
||||
content, artifact=artifact, tool_call_id=tool_call_id, name=name
|
||||
content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
name=name,
|
||||
status=status,
|
||||
)
|
||||
else:
|
||||
return content
|
||||
|
@ -524,6 +524,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -1132,6 +1141,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
|
@ -880,6 +880,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
|
@ -5806,6 +5806,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -6483,6 +6492,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -7075,6 +7093,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -7724,6 +7751,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -8376,6 +8412,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -8975,6 +9020,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -9547,6 +9601,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@ -10228,6 +10291,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
|
@ -874,7 +874,9 @@ def test_merge_tool_calls() -> None:
|
||||
|
||||
|
||||
def test_tool_message_serdes() -> None:
|
||||
message = ToolMessage("foo", artifact={"bar": {"baz": 123}}, tool_call_id="1")
|
||||
message = ToolMessage(
|
||||
"foo", artifact={"bar": {"baz": 123}}, tool_call_id="1", status="error"
|
||||
)
|
||||
ser_message = {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
@ -884,6 +886,7 @@ def test_tool_message_serdes() -> None:
|
||||
"type": "tool",
|
||||
"tool_call_id": "1",
|
||||
"artifact": {"bar": {"baz": 123}},
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
assert dumpd(message) == ser_message
|
||||
@ -911,6 +914,7 @@ def test_tool_message_ser_non_serializable() -> None:
|
||||
"id": ["tests", "unit_tests", "test_messages", "BadObject"],
|
||||
"repr": repr(bad_obj),
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
assert dumpd(message) == ser_message
|
||||
@ -931,6 +935,7 @@ def test_tool_message_to_dict() -> None:
|
||||
"name": None,
|
||||
"id": None,
|
||||
"tool_call_id": "1",
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
actual = message_to_dict(message)
|
||||
|
@ -127,6 +127,7 @@ def _merge_messages(
|
||||
"type": "tool_result",
|
||||
"content": curr.content,
|
||||
"tool_use_id": curr.tool_call_id,
|
||||
"is_error": curr.status == "error",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
@ -139,7 +139,7 @@ def test__merge_messages() -> None:
|
||||
},
|
||||
]
|
||||
),
|
||||
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
||||
ToolMessage("buz output", tool_call_id="1", status="error"), # type: ignore[misc]
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
@ -180,7 +180,12 @@ def test__merge_messages() -> None:
|
||||
),
|
||||
HumanMessage( # type: ignore[misc]
|
||||
[
|
||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": "buz output",
|
||||
"tool_use_id": "1",
|
||||
"is_error": True,
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": [
|
||||
@ -194,6 +199,7 @@ def test__merge_messages() -> None:
|
||||
},
|
||||
],
|
||||
"tool_use_id": "2",
|
||||
"is_error": False,
|
||||
},
|
||||
{"type": "text", "text": "next thing"},
|
||||
]
|
||||
@ -215,7 +221,12 @@ def test__merge_messages() -> None:
|
||||
expected = [
|
||||
HumanMessage( # type: ignore[misc]
|
||||
[
|
||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": "buz output",
|
||||
"tool_use_id": "1",
|
||||
"is_error": False,
|
||||
},
|
||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
||||
]
|
||||
)
|
||||
@ -382,7 +393,12 @@ def test__format_messages_with_tool_calls() -> None:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": "blurb",
|
||||
"tool_use_id": "1",
|
||||
"is_error": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
@ -421,7 +437,12 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": "blurb",
|
||||
"tool_use_id": "1",
|
||||
"is_error": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
@ -455,7 +476,12 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": "blurb",
|
||||
"tool_use_id": "1",
|
||||
"is_error": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
@ -502,7 +528,12 @@ def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": "blurb",
|
||||
"tool_use_id": "1",
|
||||
"is_error": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
|
@ -467,6 +467,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
"text": "green is a great pick! that's my sister's favorite color", # noqa: E501
|
||||
}
|
||||
],
|
||||
"is_error": False,
|
||||
},
|
||||
{"type": "text", "text": "what's my sister's favorite color"},
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user