diff --git a/libs/core/langchain_core/messages/ai.py b/libs/core/langchain_core/messages/ai.py index e5aec8db3f3..0f14caecf56 100644 --- a/libs/core/langchain_core/messages/ai.py +++ b/libs/core/langchain_core/messages/ai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Union from langchain_core.messages.base import ( BaseMessage, @@ -69,6 +69,37 @@ class AIMessage(BaseMessage): pass return values + def pretty_repr(self, html: bool = False) -> str: + """Return a pretty representation of the message.""" + base = super().pretty_repr(html=html) + lines = [] + + def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]: + lines = [ + f" {tc.get('name', 'Tool')} ({tc.get('id')})", + f" Call ID: {tc.get('id')}", + ] + if tc.get("error"): + lines.append(f" Error: {tc.get('error')}") + lines.append(" Args:") + args = tc.get("args") + if isinstance(args, str): + lines.append(f" {args}") + elif isinstance(args, dict): + for arg, value in args.items(): + lines.append(f" {arg}: {value}") + return lines + + if self.tool_calls: + lines.append("Tool Calls:") + for tc in self.tool_calls: + lines.extend(_format_tool_args(tc)) + if self.invalid_tool_calls: + lines.append("Invalid Tool Calls:") + for itc in self.invalid_tool_calls: + lines.extend(_format_tool_args(itc)) + return (base.strip() + "\n" + "\n".join(lines)).strip() + AIMessage.update_forward_refs()