core[patch]: introduce ToolMessage.status (#24628)

Anthropic models (including via Bedrock and other cloud platforms)
accept a status/is_error attribute on tool messages/results
(specifically in `tool_result` content blocks for Anthropic API). Adding
a ToolMessage.status attribute so that users can set this attribute when
using those models
This commit is contained in:
Bagatur 2024-07-29 14:01:53 -07:00 committed by GitHub
parent 78d97b49d9
commit a6d1fb4275
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 174 additions and 15 deletions

View File

@ -50,9 +50,6 @@ class ToolMessage(BaseMessage):
tool_call_id: str tool_call_id: str
"""Tool call that this message is responding to.""" """Tool call that this message is responding to."""
# TODO: Add is_error param?
# is_error: bool = False
# """Whether the tool errored."""
type: Literal["tool"] = "tool" type: Literal["tool"] = "tool"
"""The type of the message (used for serialization). Defaults to "tool".""" """The type of the message (used for serialization). Defaults to "tool"."""
@ -67,6 +64,12 @@ class ToolMessage(BaseMessage):
.. versionadded:: 0.2.17 .. versionadded:: 0.2.17
""" """
status: Literal["success", "error"] = "success"
"""Status of the tool invocation.
.. versionadded:: 0.2.24
"""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
@ -119,6 +122,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
self.response_metadata, other.response_metadata self.response_metadata, other.response_metadata
), ),
id=self.id, id=self.id,
status=_merge_status(self.status, other.status),
) )
return super().__add__(other) return super().__add__(other)
@ -280,3 +284,9 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
) )
tool_call_chunks.append(parsed) tool_call_chunks.append(parsed)
return tool_call_chunks return tool_call_chunks
def _merge_status(
left: Literal["success", "error"], right: Literal["success", "error"]
) -> Literal["success", "error"]:
return "error" if "error" in (left, right) else "success"

View File

@ -625,23 +625,27 @@ class ChildTool(BaseTool):
content, artifact = response content, artifact = response
else: else:
content = response content = response
status = "success"
except ValidationError as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
error_to_raise = e error_to_raise = e
else: else:
content = _handle_validation_error(e, flag=self.handle_validation_error) content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e: except ToolException as e:
if not self.handle_tool_error: if not self.handle_tool_error:
error_to_raise = e error_to_raise = e
else: else:
content = _handle_tool_error(e, flag=self.handle_tool_error) content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
error_to_raise = e error_to_raise = e
status = "error"
if error_to_raise: if error_to_raise:
run_manager.on_tool_error(error_to_raise) run_manager.on_tool_error(error_to_raise)
raise error_to_raise raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name) output = _format_output(content, artifact, tool_call_id, self.name, status)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output return output
@ -737,24 +741,28 @@ class ChildTool(BaseTool):
content, artifact = response content, artifact = response
else: else:
content = response content = response
status = "success"
except ValidationError as e: except ValidationError as e:
if not self.handle_validation_error: if not self.handle_validation_error:
error_to_raise = e error_to_raise = e
else: else:
content = _handle_validation_error(e, flag=self.handle_validation_error) content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e: except ToolException as e:
if not self.handle_tool_error: if not self.handle_tool_error:
error_to_raise = e error_to_raise = e
else: else:
content = _handle_tool_error(e, flag=self.handle_tool_error) content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
error_to_raise = e error_to_raise = e
status = "error"
if error_to_raise: if error_to_raise:
await run_manager.on_tool_error(error_to_raise) await run_manager.on_tool_error(error_to_raise)
raise error_to_raise raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name) output = _format_output(content, artifact, tool_call_id, self.name, status)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
return output return output
@ -1511,7 +1519,7 @@ def _prep_run_args(
def _format_output( def _format_output(
content: Any, artifact: Any, tool_call_id: Optional[str], name: str content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
) -> Union[ToolMessage, Any]: ) -> Union[ToolMessage, Any]:
if tool_call_id: if tool_call_id:
# NOTE: This will fail to stringify lists which aren't actually content blocks # NOTE: This will fail to stringify lists which aren't actually content blocks
@ -1524,7 +1532,11 @@ def _format_output(
): ):
content = _stringify(content) content = _stringify(content)
return ToolMessage( return ToolMessage(
content, artifact=artifact, tool_call_id=tool_call_id, name=name content,
artifact=artifact,
tool_call_id=tool_call_id,
name=name,
status=status,
) )
else: else:
return content return content

View File

@ -524,6 +524,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -1132,6 +1141,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',

View File

@ -880,6 +880,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',

View File

@ -5806,6 +5806,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -6483,6 +6492,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -7075,6 +7093,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -7724,6 +7751,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -8376,6 +8412,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -8975,6 +9020,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -9547,6 +9601,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',
@ -10228,6 +10291,15 @@
'title': 'Response Metadata', 'title': 'Response Metadata',
'type': 'object', 'type': 'object',
}), }),
'status': dict({
'default': 'success',
'enum': list([
'success',
'error',
]),
'title': 'Status',
'type': 'string',
}),
'tool_call_id': dict({ 'tool_call_id': dict({
'title': 'Tool Call Id', 'title': 'Tool Call Id',
'type': 'string', 'type': 'string',

View File

@ -874,7 +874,9 @@ def test_merge_tool_calls() -> None:
def test_tool_message_serdes() -> None: def test_tool_message_serdes() -> None:
message = ToolMessage("foo", artifact={"bar": {"baz": 123}}, tool_call_id="1") message = ToolMessage(
"foo", artifact={"bar": {"baz": 123}}, tool_call_id="1", status="error"
)
ser_message = { ser_message = {
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
@ -884,6 +886,7 @@ def test_tool_message_serdes() -> None:
"type": "tool", "type": "tool",
"tool_call_id": "1", "tool_call_id": "1",
"artifact": {"bar": {"baz": 123}}, "artifact": {"bar": {"baz": 123}},
"status": "error",
}, },
} }
assert dumpd(message) == ser_message assert dumpd(message) == ser_message
@ -911,6 +914,7 @@ def test_tool_message_ser_non_serializable() -> None:
"id": ["tests", "unit_tests", "test_messages", "BadObject"], "id": ["tests", "unit_tests", "test_messages", "BadObject"],
"repr": repr(bad_obj), "repr": repr(bad_obj),
}, },
"status": "success",
}, },
} }
assert dumpd(message) == ser_message assert dumpd(message) == ser_message
@ -931,6 +935,7 @@ def test_tool_message_to_dict() -> None:
"name": None, "name": None,
"id": None, "id": None,
"tool_call_id": "1", "tool_call_id": "1",
"status": "success",
}, },
} }
actual = message_to_dict(message) actual = message_to_dict(message)

View File

@ -127,6 +127,7 @@ def _merge_messages(
"type": "tool_result", "type": "tool_result",
"content": curr.content, "content": curr.content,
"tool_use_id": curr.tool_call_id, "tool_use_id": curr.tool_call_id,
"is_error": curr.status == "error",
} }
] ]
) )

View File

@ -139,7 +139,7 @@ def test__merge_messages() -> None:
}, },
] ]
), ),
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] ToolMessage("buz output", tool_call_id="1", status="error"), # type: ignore[misc]
ToolMessage( ToolMessage(
content=[ content=[
{ {
@ -180,7 +180,12 @@ def test__merge_messages() -> None:
), ),
HumanMessage( # type: ignore[misc] HumanMessage( # type: ignore[misc]
[ [
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, {
"type": "tool_result",
"content": "buz output",
"tool_use_id": "1",
"is_error": True,
},
{ {
"type": "tool_result", "type": "tool_result",
"content": [ "content": [
@ -194,6 +199,7 @@ def test__merge_messages() -> None:
}, },
], ],
"tool_use_id": "2", "tool_use_id": "2",
"is_error": False,
}, },
{"type": "text", "text": "next thing"}, {"type": "text", "text": "next thing"},
] ]
@ -215,7 +221,12 @@ def test__merge_messages() -> None:
expected = [ expected = [
HumanMessage( # type: ignore[misc] HumanMessage( # type: ignore[misc]
[ [
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, {
"type": "tool_result",
"content": "buz output",
"tool_use_id": "1",
"is_error": False,
},
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, {"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
] ]
) )
@ -382,7 +393,12 @@ def test__format_messages_with_tool_calls() -> None:
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"} {
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
], ],
}, },
], ],
@ -421,7 +437,12 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"} {
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
], ],
}, },
], ],
@ -455,7 +476,12 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"} {
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
], ],
}, },
], ],
@ -502,7 +528,12 @@ def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None:
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"} {
"type": "tool_result",
"content": "blurb",
"tool_use_id": "1",
"is_error": False,
}
], ],
}, },
], ],

View File

@ -467,6 +467,7 @@ class ChatModelIntegrationTests(ChatModelTests):
"text": "green is a great pick! that's my sister's favorite color", # noqa: E501 "text": "green is a great pick! that's my sister's favorite color", # noqa: E501
} }
], ],
"is_error": False,
}, },
{"type": "text", "text": "what's my sister's favorite color"}, {"type": "text", "text": "what's my sister's favorite color"},
] ]