mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
core[patch]: introduce ToolMessage.status (#24628)
Anthropic models (including via Bedrock and other cloud platforms) accept a status/is_error attribute on tool messages/results (specifically in `tool_result` content blocks for Anthropic API). Adding a ToolMessage.status attribute so that users can set this attribute when using those models
This commit is contained in:
parent
78d97b49d9
commit
a6d1fb4275
@ -50,9 +50,6 @@ class ToolMessage(BaseMessage):
|
|||||||
|
|
||||||
tool_call_id: str
|
tool_call_id: str
|
||||||
"""Tool call that this message is responding to."""
|
"""Tool call that this message is responding to."""
|
||||||
# TODO: Add is_error param?
|
|
||||||
# is_error: bool = False
|
|
||||||
# """Whether the tool errored."""
|
|
||||||
|
|
||||||
type: Literal["tool"] = "tool"
|
type: Literal["tool"] = "tool"
|
||||||
"""The type of the message (used for serialization). Defaults to "tool"."""
|
"""The type of the message (used for serialization). Defaults to "tool"."""
|
||||||
@ -67,6 +64,12 @@ class ToolMessage(BaseMessage):
|
|||||||
.. versionadded:: 0.2.17
|
.. versionadded:: 0.2.17
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
status: Literal["success", "error"] = "success"
|
||||||
|
"""Status of the tool invocation.
|
||||||
|
|
||||||
|
.. versionadded:: 0.2.24
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object.
|
"""Get the namespace of the langchain object.
|
||||||
@ -119,6 +122,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
|||||||
self.response_metadata, other.response_metadata
|
self.response_metadata, other.response_metadata
|
||||||
),
|
),
|
||||||
id=self.id,
|
id=self.id,
|
||||||
|
status=_merge_status(self.status, other.status),
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().__add__(other)
|
return super().__add__(other)
|
||||||
@ -280,3 +284,9 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
|
|||||||
)
|
)
|
||||||
tool_call_chunks.append(parsed)
|
tool_call_chunks.append(parsed)
|
||||||
return tool_call_chunks
|
return tool_call_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_status(
|
||||||
|
left: Literal["success", "error"], right: Literal["success", "error"]
|
||||||
|
) -> Literal["success", "error"]:
|
||||||
|
return "error" if "error" in (left, right) else "success"
|
||||||
|
@ -625,23 +625,27 @@ class ChildTool(BaseTool):
|
|||||||
content, artifact = response
|
content, artifact = response
|
||||||
else:
|
else:
|
||||||
content = response
|
content = response
|
||||||
|
status = "success"
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||||
|
status = "error"
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
|
status = "error"
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
|
status = "error"
|
||||||
|
|
||||||
if error_to_raise:
|
if error_to_raise:
|
||||||
run_manager.on_tool_error(error_to_raise)
|
run_manager.on_tool_error(error_to_raise)
|
||||||
raise error_to_raise
|
raise error_to_raise
|
||||||
output = _format_output(content, artifact, tool_call_id, self.name)
|
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||||
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -737,24 +741,28 @@ class ChildTool(BaseTool):
|
|||||||
content, artifact = response
|
content, artifact = response
|
||||||
else:
|
else:
|
||||||
content = response
|
content = response
|
||||||
|
status = "success"
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||||
|
status = "error"
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
else:
|
else:
|
||||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
|
status = "error"
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
error_to_raise = e
|
error_to_raise = e
|
||||||
|
status = "error"
|
||||||
|
|
||||||
if error_to_raise:
|
if error_to_raise:
|
||||||
await run_manager.on_tool_error(error_to_raise)
|
await run_manager.on_tool_error(error_to_raise)
|
||||||
raise error_to_raise
|
raise error_to_raise
|
||||||
|
|
||||||
output = _format_output(content, artifact, tool_call_id, self.name)
|
output = _format_output(content, artifact, tool_call_id, self.name, status)
|
||||||
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -1511,7 +1519,7 @@ def _prep_run_args(
|
|||||||
|
|
||||||
|
|
||||||
def _format_output(
|
def _format_output(
|
||||||
content: Any, artifact: Any, tool_call_id: Optional[str], name: str
|
content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str
|
||||||
) -> Union[ToolMessage, Any]:
|
) -> Union[ToolMessage, Any]:
|
||||||
if tool_call_id:
|
if tool_call_id:
|
||||||
# NOTE: This will fail to stringify lists which aren't actually content blocks
|
# NOTE: This will fail to stringify lists which aren't actually content blocks
|
||||||
@ -1524,7 +1532,11 @@ def _format_output(
|
|||||||
):
|
):
|
||||||
content = _stringify(content)
|
content = _stringify(content)
|
||||||
return ToolMessage(
|
return ToolMessage(
|
||||||
content, artifact=artifact, tool_call_id=tool_call_id, name=name
|
content,
|
||||||
|
artifact=artifact,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
name=name,
|
||||||
|
status=status,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return content
|
return content
|
||||||
|
@ -524,6 +524,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -1132,6 +1141,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
|
@ -880,6 +880,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
|
@ -5806,6 +5806,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -6483,6 +6492,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -7075,6 +7093,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -7724,6 +7751,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -8376,6 +8412,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -8975,6 +9020,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -9547,6 +9601,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
@ -10228,6 +10291,15 @@
|
|||||||
'title': 'Response Metadata',
|
'title': 'Response Metadata',
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
}),
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'default': 'success',
|
||||||
|
'enum': list([
|
||||||
|
'success',
|
||||||
|
'error',
|
||||||
|
]),
|
||||||
|
'title': 'Status',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
'tool_call_id': dict({
|
'tool_call_id': dict({
|
||||||
'title': 'Tool Call Id',
|
'title': 'Tool Call Id',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
|
@ -874,7 +874,9 @@ def test_merge_tool_calls() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_tool_message_serdes() -> None:
|
def test_tool_message_serdes() -> None:
|
||||||
message = ToolMessage("foo", artifact={"bar": {"baz": 123}}, tool_call_id="1")
|
message = ToolMessage(
|
||||||
|
"foo", artifact={"bar": {"baz": 123}}, tool_call_id="1", status="error"
|
||||||
|
)
|
||||||
ser_message = {
|
ser_message = {
|
||||||
"lc": 1,
|
"lc": 1,
|
||||||
"type": "constructor",
|
"type": "constructor",
|
||||||
@ -884,6 +886,7 @@ def test_tool_message_serdes() -> None:
|
|||||||
"type": "tool",
|
"type": "tool",
|
||||||
"tool_call_id": "1",
|
"tool_call_id": "1",
|
||||||
"artifact": {"bar": {"baz": 123}},
|
"artifact": {"bar": {"baz": 123}},
|
||||||
|
"status": "error",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert dumpd(message) == ser_message
|
assert dumpd(message) == ser_message
|
||||||
@ -911,6 +914,7 @@ def test_tool_message_ser_non_serializable() -> None:
|
|||||||
"id": ["tests", "unit_tests", "test_messages", "BadObject"],
|
"id": ["tests", "unit_tests", "test_messages", "BadObject"],
|
||||||
"repr": repr(bad_obj),
|
"repr": repr(bad_obj),
|
||||||
},
|
},
|
||||||
|
"status": "success",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert dumpd(message) == ser_message
|
assert dumpd(message) == ser_message
|
||||||
@ -931,6 +935,7 @@ def test_tool_message_to_dict() -> None:
|
|||||||
"name": None,
|
"name": None,
|
||||||
"id": None,
|
"id": None,
|
||||||
"tool_call_id": "1",
|
"tool_call_id": "1",
|
||||||
|
"status": "success",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
actual = message_to_dict(message)
|
actual = message_to_dict(message)
|
||||||
|
@ -127,6 +127,7 @@ def _merge_messages(
|
|||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"content": curr.content,
|
"content": curr.content,
|
||||||
"tool_use_id": curr.tool_call_id,
|
"tool_use_id": curr.tool_call_id,
|
||||||
|
"is_error": curr.status == "error",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -139,7 +139,7 @@ def test__merge_messages() -> None:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
ToolMessage("buz output", tool_call_id="1", status="error"), # type: ignore[misc]
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=[
|
content=[
|
||||||
{
|
{
|
||||||
@ -180,7 +180,12 @@ def test__merge_messages() -> None:
|
|||||||
),
|
),
|
||||||
HumanMessage( # type: ignore[misc]
|
HumanMessage( # type: ignore[misc]
|
||||||
[
|
[
|
||||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": "buz output",
|
||||||
|
"tool_use_id": "1",
|
||||||
|
"is_error": True,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"content": [
|
"content": [
|
||||||
@ -194,6 +199,7 @@ def test__merge_messages() -> None:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"tool_use_id": "2",
|
"tool_use_id": "2",
|
||||||
|
"is_error": False,
|
||||||
},
|
},
|
||||||
{"type": "text", "text": "next thing"},
|
{"type": "text", "text": "next thing"},
|
||||||
]
|
]
|
||||||
@ -215,7 +221,12 @@ def test__merge_messages() -> None:
|
|||||||
expected = [
|
expected = [
|
||||||
HumanMessage( # type: ignore[misc]
|
HumanMessage( # type: ignore[misc]
|
||||||
[
|
[
|
||||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": "buz output",
|
||||||
|
"tool_use_id": "1",
|
||||||
|
"is_error": False,
|
||||||
|
},
|
||||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -382,7 +393,12 @@ def test__format_messages_with_tool_calls() -> None:
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": "blurb",
|
||||||
|
"tool_use_id": "1",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -421,7 +437,12 @@ def test__format_messages_with_str_content_and_tool_calls() -> None:
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": "blurb",
|
||||||
|
"tool_use_id": "1",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -455,7 +476,12 @@ def test__format_messages_with_list_content_and_tool_calls() -> None:
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": "blurb",
|
||||||
|
"tool_use_id": "1",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -502,7 +528,12 @@ def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None:
|
|||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "tool_result", "content": "blurb", "tool_use_id": "1"}
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": "blurb",
|
||||||
|
"tool_use_id": "1",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
@ -467,6 +467,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
"text": "green is a great pick! that's my sister's favorite color", # noqa: E501
|
"text": "green is a great pick! that's my sister's favorite color", # noqa: E501
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"is_error": False,
|
||||||
},
|
},
|
||||||
{"type": "text", "text": "what's my sister's favorite color"},
|
{"type": "text", "text": "what's my sister's favorite color"},
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user