core[patch]: introduce ToolMessage.status (#24628)

Anthropic models (including via Bedrock and other cloud platforms)
accept a status/is_error attribute on tool messages/results
(specifically in `tool_result` content blocks for Anthropic API). Adding
a ToolMessage.status attribute so that users can set this attribute when
using those models
This commit is contained in:
Bagatur 2024-07-29 14:01:53 -07:00 committed by GitHub
parent 78d97b49d9
commit a6d1fb4275
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 174 additions and 15 deletions

View File

@ -50,9 +50,6 @@ class ToolMessage(BaseMessage):
tool_call_id: str
"""Tool call that this message is responding to."""
# TODO: Add is_error param?
# is_error: bool = False
# """Whether the tool errored."""
type: Literal["tool"] = "tool"
"""The type of the message (used for serialization). Defaults to "tool"."""
@ -67,6 +64,12 @@ class ToolMessage(BaseMessage):
.. versionadded:: 0.2.17
"""
status: Literal["success", "error"] = "success"
"""Status of the tool invocation.
.. versionadded:: 0.2.24
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.
@ -119,6 +122,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
self.response_metadata, other.response_metadata
),
id=self.id,
status=_merge_status(self.status, other.status),
)
return super().__add__(other)
@ -280,3 +284,9 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
)
tool_call_chunks.append(parsed)
return tool_call_chunks
def _merge_status(
left: Literal["success", "error"], right: Literal["success", "error"]
) -> Literal["success", "error"]:
return "error" if "error" in (left, right) else "success"

View File

@ -625,23 +625,27 @@ class ChildTool(BaseTool):
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name)
output = _format_output(content, artifact, tool_call_id, self.name, status)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@ -737,24 +741,28 @@ class ChildTool(BaseTool):
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
await run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name)
output = _format_output(content, artifact, tool_call_id, self.name, status)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output
@ -1511,7 +1519,7 @@ def _prep_run_args(
def _format_output(
content: Any, artifact: Any, tool_call_id: Optional[str], name: str
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
) -> Union[ToolMessage, Any]:
if tool_call_id:
# NOTE: This will fail to stringify lists which aren't actually content blocks
@ -1524,7 +1532,11 @@ def _format_output(
):
content = _stringify(content)
return ToolMessage(
content, artifact=artifact, tool_call_id=tool_call_id, name=name
content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
)
else:
return content

View File

@ -524,6 +524,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -1132,6 +1141,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',

View File

@ -880,6 +880,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',

View File

@ -5806,6 +5806,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -6483,6 +6492,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -7075,6 +7093,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -7724,6 +7751,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -8376,6 +8412,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -8975,6 +9020,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -9547,6 +9601,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
@ -10228,6 +10291,15 @@
'title': 'Response Metadata',
'type': 'object',
}),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',

View File

@ -874,7 +874,9 @@ def test_merge_tool_calls() -> None:
def test_tool_message_serdes() -> None:
message = ToolMessage("foo", artifact={"bar": {"baz": 123}}, tool_call_id="1")
message = ToolMessage(
"foo", artifact={"bar": {"baz": 123}}, tool_call_id="1", status="error"
)
ser_message = {
"lc": 1,
"type": "constructor",
@ -884,6 +886,7 @@ def test_tool_message_serdes() -> None:
"type": "tool",
"tool_call_id": "1",
"artifact": {"bar": {"baz": 123}},
"status": "error",
},
}
assert dumpd(message) == ser_message
@ -911,6 +914,7 @@ def test_tool_message_ser_non_serializable() -> None:
"id": ["tests", "unit_tests", "test_messages", "BadObject"],
"repr": repr(bad_obj),
},
"status": "success",
},
}
assert dumpd(message) == ser_message
@ -931,6 +935,7 @@ def test_tool_message_to_dict() -> None:
"name": None,
"id": None,
"tool_call_id": "1",
"status": "success",
},
}
actual = message_to_dict(message)

View File

@ -127,6 +127,7 @@ def _merge_messages(
"type": "tool_result",
"content": curr.content,
"tool_use_id": curr.tool_call_id,
"is_error": curr.status == "error",
}
]
)

View File

@ -139,7 +139,7 @@ def test__merge_messages() -> None:
},
]
),
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
ToolMessage("buz output", tool_call_id="1", status="error"), # type: ignore[misc]
ToolMessage(
content=[
{
@ -180,7 +180,12 @@ def test__merge_messages() -> None:
),
HumanMessage( # type: ignore[misc]
[
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
{
"type": "tool_result",
"content": "buz output",
"tool_use_id": "1",
"is_error": True,
},
{
"type": "tool_result",
"content": [
@ -194,6 +199,7 @@ def test__merge_messages() -> None:
},
],
"tool_use_id": "2",
"is_error": False,
},
{"type": "text", "text": "next thing"},
]
@ -215,7 +221,12 @@ def test__merge_messages() -> None:
expected = [
HumanMessage( # type: ignore[misc]
[
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
{
"type": "tool_result",
"content": "buz output",
"tool_use_id": "1",
"is_error": False,
},
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
]
)
@ -382,7 +393,12 @@ def test__format_messages_with_tool_calls() -> None:
{
"role": "user",
"content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
{
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
],
},
],
@ -421,7 +437,12 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
{
"role": "user",
"content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
{
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
],
},
],
@ -455,7 +476,12 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
{
"role": "user",
"content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
{
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
],
},
],
@ -502,7 +528,12 @@ def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None:
{
"role": "user",
"content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
{
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
],
},
],

View File

@ -467,6 +467,7 @@ class ChatModelIntegrationTests(ChatModelTests):
"text": "green is a great pick! that's my sister's favorite color", # noqa: E501
}
],
"is_error": False,
},
{"type": "text", "text": "what's my sister's favorite color"},
]