diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 87c0fe5f4ea..4a5ac5f6121 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -50,9 +50,6 @@ class ToolMessage(BaseMessage): tool_call_id: str """Tool call that this message is responding to.""" - # TODO: Add is_error param? - # is_error: bool = False - # """Whether the tool errored.""" type: Literal["tool"] = "tool" """The type of the message (used for serialization). Defaults to "tool".""" @@ -67,6 +64,12 @@ class ToolMessage(BaseMessage): .. versionadded:: 0.2.17 """ + status: Literal["success", "error"] = "success" + """Status of the tool invocation. + + .. versionadded:: 0.2.24 + """ + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object. @@ -119,6 +122,7 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk): self.response_metadata, other.response_metadata ), id=self.id, + status=_merge_status(self.status, other.status), ) return super().__add__(other) @@ -280,3 +284,9 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk] ) tool_call_chunks.append(parsed) return tool_call_chunks + + +def _merge_status( + left: Literal["success", "error"], right: Literal["success", "error"] +) -> Literal["success", "error"]: + return "error" if "error" in (left, right) else "success" diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 0b251881436..9a28a56ba48 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -625,23 +625,27 @@ class ChildTool(BaseTool): content, artifact = response else: content = response + status = "success" except ValidationError as e: if not self.handle_validation_error: error_to_raise = e else: content = _handle_validation_error(e, flag=self.handle_validation_error) + status = "error" except ToolException as e: if not self.handle_tool_error: error_to_raise = e else: content = _handle_tool_error(e, flag=self.handle_tool_error) + status = "error" except (Exception, KeyboardInterrupt) as e: error_to_raise = e + status = "error" if error_to_raise: run_manager.on_tool_error(error_to_raise) raise error_to_raise - output = _format_output(content, artifact, tool_call_id, self.name) + output = _format_output(content, artifact, tool_call_id, self.name, status) run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) return output @@ -737,24 +741,28 @@ class ChildTool(BaseTool): content, artifact = response else: content = response + status = "success" except ValidationError as e: if not self.handle_validation_error: error_to_raise = e else: content = _handle_validation_error(e, flag=self.handle_validation_error) + status = "error" except ToolException as e: if not self.handle_tool_error: error_to_raise = e else: content = _handle_tool_error(e, flag=self.handle_tool_error) + status = "error" except (Exception, KeyboardInterrupt) as e: error_to_raise = e + status = "error" if error_to_raise: await run_manager.on_tool_error(error_to_raise) raise error_to_raise - output = _format_output(content, artifact, tool_call_id, self.name) + output = _format_output(content, artifact, tool_call_id, self.name, status) await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) return output @@ -1511,7 +1519,7 @@ def _prep_run_args( def _format_output( - content: Any, artifact: Any, tool_call_id: Optional[str], name: str + content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str ) -> Union[ToolMessage, Any]: if tool_call_id: # NOTE: This will fail to stringify lists which aren't actually content blocks @@ -1524,7 +1532,11 @@ def _format_output( ): content = _stringify(content) return ToolMessage( - content, artifact=artifact, tool_call_id=tool_call_id, name=name + content, + artifact=artifact, + tool_call_id=tool_call_id, + name=name, + status=status, ) else: return content diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 4a7ae82b270..28ef4b9f0fa 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -524,6 +524,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -1132,6 +1141,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 5eb5c50b13d..75bdc180919 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -880,6 +880,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 144c4118d92..88b421db6bb 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -5806,6 +5806,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -6483,6 +6492,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -7075,6 +7093,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -7724,6 +7751,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -8376,6 +8412,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -8975,6 +9020,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -9547,6 +9601,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', @@ -10228,6 +10291,15 @@ 'title': 'Response Metadata', 'type': 'object', }), + 'status': dict({ + 'default': 'success', + 'enum': list([ + 'success', + 'error', + ]), + 'title': 'Status', + 'type': 'string', + }), 'tool_call_id': dict({ 'title': 'Tool Call Id', 'type': 'string', diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index b62803a3b40..7fdbfc5e66b 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -874,7 +874,9 @@ def test_merge_tool_calls() -> None: def test_tool_message_serdes() -> None: - message = ToolMessage("foo", artifact={"bar": {"baz": 123}}, tool_call_id="1") + message = ToolMessage( + "foo", artifact={"bar": {"baz": 123}}, tool_call_id="1", status="error" + ) ser_message = { "lc": 1, "type": "constructor", @@ -884,6 +886,7 @@ def test_tool_message_serdes() -> None: "type": "tool", "tool_call_id": "1", "artifact": {"bar": {"baz": 123}}, + "status": "error", }, } assert dumpd(message) == ser_message @@ -911,6 +914,7 @@ def test_tool_message_ser_non_serializable() -> None: "id": ["tests", "unit_tests", "test_messages", "BadObject"], "repr": repr(bad_obj), }, + "status": "success", }, } assert dumpd(message) == ser_message @@ -931,6 +935,7 @@ def test_tool_message_to_dict() -> None: "name": None, "id": None, "tool_call_id": "1", + "status": "success", }, } actual = message_to_dict(message) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index da487a5ffec..8b21e374e3e 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -127,6 +127,7 @@ def _merge_messages( "type": "tool_result", "content": curr.content, "tool_use_id": curr.tool_call_id, + "is_error": curr.status == "error", } ] ) diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index d7e21edb9ac..187dae90d4c 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -139,7 +139,7 @@ def test__merge_messages() -> None: }, ] ), - ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] + ToolMessage("buz output", tool_call_id="1", status="error"), # type: ignore[misc] ToolMessage( content=[ { @@ -180,7 +180,12 @@ def test__merge_messages() -> None: ), HumanMessage( # type: ignore[misc] [ - {"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, + { + "type": "tool_result", + "content": "buz output", + "tool_use_id": "1", + "is_error": True, + }, { "type": "tool_result", "content": [ @@ -194,6 +199,7 @@ def test__merge_messages() -> None: }, ], "tool_use_id": "2", + "is_error": False, }, {"type": "text", "text": "next thing"}, ] @@ -215,7 +221,12 @@ def test__merge_messages() -> None: expected = [ HumanMessage( # type: ignore[misc] [ - {"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, + { + "type": "tool_result", + "content": "buz output", + "tool_use_id": "1", + "is_error": False, + }, {"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, ] ) @@ -382,7 +393,12 @@ def test__format_messages_with_tool_calls() -> None: { "role": "user", "content": [ - {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + { + "type": "tool_result", + "content": "blurb", + "tool_use_id": "1", + "is_error": False, + } ], }, ], @@ -421,7 +437,12 @@ def test__format_messages_with_str_content_and_tool_calls() -> None: { "role": "user", "content": [ - {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + { + "type": "tool_result", + "content": "blurb", + "tool_use_id": "1", + "is_error": False, + } ], }, ], @@ -455,7 +476,12 @@ def test__format_messages_with_list_content_and_tool_calls() -> None: { "role": "user", "content": [ - {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + { + "type": "tool_result", + "content": "blurb", + "tool_use_id": "1", + "is_error": False, + } ], }, ], @@ -502,7 +528,12 @@ def test__format_messages_with_tool_use_blocks_and_tool_calls() -> None: { "role": "user", "content": [ - {"type": "tool_result", "content": "blurb", "tool_use_id": "1"} + { + "type": "tool_result", + "content": "blurb", + "tool_use_id": "1", + "is_error": False, + } ], }, ], diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 3e3fadb756c..9aac22360db 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -467,6 +467,7 @@ class ChatModelIntegrationTests(ChatModelTests): "text": "green is a great pick! that's my sister's favorite color", # noqa: E501 } ], + "is_error": False, }, {"type": "text", "text": "what's my sister's favorite color"}, ]