From 161f6af9ce1590aac5640b44a744365ed21feeab Mon Sep 17 00:00:00 2001 From: vbarda Date: Tue, 12 Nov 2024 15:04:47 -0500 Subject: [PATCH] [rfc] core: make tool_call_id optional --- libs/core/langchain_core/messages/tool.py | 4 ++-- libs/core/langchain_core/tools/base.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 653dd838f86..b674665dd54 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -50,7 +50,7 @@ class ToolMessage(BaseMessage): to request multiple tool calls in parallel. """ # noqa: E501 - tool_call_id: str + tool_call_id: Optional[str] = None """Tool call that this message is responding to.""" type: Literal["tool"] = "tool" @@ -119,7 +119,7 @@ class ToolMessage(BaseMessage): else: pass - tool_call_id = values["tool_call_id"] + tool_call_id = values.get("tool_call_id") if isinstance(tool_call_id, (UUID, int, float)): values["tool_call_id"] = str(tool_call_id) return values diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 9782234dfb1..e854b668cde 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -892,6 +892,9 @@ def _format_output( content: Any, artifact: Any, tool_call_id: Optional[str], name: str, status: str ) -> Union[ToolMessage, Any]: if tool_call_id: + if isinstance(content, ToolMessage): + content.tool_call_id = tool_call_id + return content if not _is_message_content_type(content): content = _stringify(content) return ToolMessage(