mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
core[patch]: introduce ToolMessage.status (#24628)
Anthropic models (including via Bedrock and other cloud platforms) accept a status/is_error attribute on tool messages/results (specifically in `tool_result` content blocks for Anthropic API). Adding a ToolMessage.status attribute so that users can set this attribute when using those models
This commit is contained in:
@@ -50,9 +50,6 @@ class ToolMessage(BaseMessage):
|
||||
|
||||
tool_call_id: str
|
||||
"""Tool call that this message is responding to."""
|
||||
# TODO: Add is_error param?
|
||||
# is_error: bool = False
|
||||
# """Whether the tool errored."""
|
||||
|
||||
type: Literal["tool"] = "tool"
|
||||
"""The type of the message (used for serialization). Defaults to "tool"."""
|
||||
@@ -67,6 +64,12 @@ class ToolMessage(BaseMessage):
|
||||
.. versionadded:: 0.2.17
|
||||
"""
|
||||
|
||||
status: Literal["success", "error"] = "success"
|
||||
"""Status of the tool invocation.
|
||||
|
||||
.. versionadded:: 0.2.24
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object.
|
||||
@@ -119,6 +122,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||
self.response_metadata, other.response_metadata
|
||||
),
|
||||
id=self.id,
|
||||
status=_merge_status(self.status, other.status),
|
||||
)
|
||||
|
||||
return super().__add__(other)
|
||||
@@ -280,3 +284,9 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
|
||||
)
|
||||
tool_call_chunks.append(parsed)
|
||||
return tool_call_chunks
|
||||
|
||||
|
||||
def _merge_status(
|
||||
left: Literal["success", "error"], right: Literal["success", "error"]
|
||||
) -> Literal["success", "error"]:
|
||||
return "error" if "error" in (left, right) else "success"
|
||||
|
@@ -625,23 +625,27 @@ class ChildTool(BaseTool):
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
output = _format_output(content, artifact, tool_call_id, self.name)
|
||||
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@@ -737,24 +741,28 @@ class ChildTool(BaseTool):
|
||||
content, artifact = response
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
|
||||
if error_to_raise:
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
raise error_to_raise
|
||||
|
||||
output = _format_output(content, artifact, tool_call_id, self.name)
|
||||
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||
return output
|
||||
|
||||
@@ -1511,7 +1519,7 @@ def _prep_run_args(
|
||||
|
||||
|
||||
def _format_output(
|
||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str
|
||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
|
||||
) -> Union[ToolMessage, Any]:
|
||||
if tool_call_id:
|
||||
# NOTE: This will fail to stringify lists which aren't actually content blocks
|
||||
@@ -1524,7 +1532,11 @@ def _format_output(
|
||||
):
|
||||
content = _stringify(content)
|
||||
return ToolMessage(
|
||||
content, artifact=artifact, tool_call_id=tool_call_id, name=name
|
||||
content,
|
||||
artifact=artifact,
|
||||
tool_call_id=tool_call_id,
|
||||
name=name,
|
||||
status=status,
|
||||
)
|
||||
else:
|
||||
return content
|
||||
|
@@ -524,6 +524,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -1132,6 +1141,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
|
@@ -880,6 +880,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
|
@@ -5806,6 +5806,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -6483,6 +6492,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -7075,6 +7093,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -7724,6 +7751,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -8376,6 +8412,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -8975,6 +9020,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -9547,6 +9601,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
@@ -10228,6 +10291,15 @@
|
||||
'title': 'Response Metadata',
|
||||
'type': 'object',
|
||||
}),
|
||||
'status': dict({
|
||||
'default': 'success',
|
||||
'enum': list([
|
||||
'success',
|
||||
'error',
|
||||
]),
|
||||
'title': 'Status',
|
||||
'type': 'string',
|
||||
}),
|
||||
'tool_call_id': dict({
|
||||
'title': 'Tool Call Id',
|
||||
'type': 'string',
|
||||
|
@@ -874,7 +874,9 @@ def test_merge_tool_calls() -> None:
|
||||
|
||||
|
||||
def test_tool_message_serdes() -> None:
|
||||
message = ToolMessage("foo", artifact={"bar": {"baz": 123}}, tool_call_id="1")
|
||||
message = ToolMessage(
|
||||
"foo", artifact={"bar": {"baz": 123}}, tool_call_id="1", status="error"
|
||||
)
|
||||
ser_message = {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
@@ -884,6 +886,7 @@ def test_tool_message_serdes() -> None:
|
||||
"type": "tool",
|
||||
"tool_call_id": "1",
|
||||
"artifact": {"bar": {"baz": 123}},
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
assert dumpd(message) == ser_message
|
||||
@@ -911,6 +914,7 @@ def test_tool_message_ser_non_serializable() -> None:
|
||||
"id": ["tests", "unit_tests", "test_messages", "BadObject"],
|
||||
"repr": repr(bad_obj),
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
assert dumpd(message) == ser_message
|
||||
@@ -931,6 +935,7 @@ def test_tool_message_to_dict() -> None:
|
||||
"name": None,
|
||||
"id": None,
|
||||
"tool_call_id": "1",
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
actual = message_to_dict(message)
|
||||
|
Reference in New Issue
Block a user