mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-05 03:48:48 +00:00
feat(core): add XML format option for get_buffer_string (#34802)
## Summary Add XML format option for `get_buffer_string()` to provide unambiguous message serialization. This fixes role prefix ambiguity when message content contains strings like "Human:" or "AI:". Fixes #34786 ## Changes - Add `format="xml"` parameter with proper XML escaping using `quoteattr()` for attributes - Add explicit validation for format parameter (raises `ValueError` for invalid values) - Add comprehensive tests for XML format edge cases <img width="1952" height="706" alt="image" src="https://github.com/user-attachments/assets/1cd6f887-9365-43cf-a532-72d7addd8bad" /> <img width="2786" height="776" alt="image" src="https://github.com/user-attachments/assets/a07b0db0-519c-46d7-b34b-b404237d812b" /> --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
@@ -28,6 +28,7 @@ from typing import (
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from xml.sax.saxutils import escape, quoteattr
|
||||
|
||||
from pydantic import Discriminator, Field, Tag
|
||||
|
||||
@@ -98,11 +99,199 @@ AnyMessage = Annotated[
|
||||
"""A type representing any defined `Message` or `MessageChunk` type."""
|
||||
|
||||
|
||||
def _has_base64_data(block: dict) -> bool:
    """Report whether a content block carries inline base64 data.

    Args:
        block: A content block dictionary.

    Returns:
        `True` when the block has a truthy `base64` field, a string `url` that
        starts with `data:`, or an OpenAI-style `image_url` dict whose `url` is
        such a string; `False` otherwise.
    """

    def _is_data_url(candidate: object) -> bool:
        # Only strings can be data: URLs; anything else is ignored.
        return isinstance(candidate, str) and candidate.startswith("data:")

    # Explicit base64 field (standard content blocks)
    if block.get("base64"):
        return True

    # data: URL directly in the url field
    if _is_data_url(block.get("url", "")):
        return True

    # OpenAI-style image_url wrapper around a data: URL
    nested = block.get("image_url", {})
    return isinstance(nested, dict) and _is_data_url(nested.get("url", ""))
|
||||
|
||||
|
||||
_XML_CONTENT_BLOCK_MAX_LEN = 500


def _truncate(text: str, max_len: int = _XML_CONTENT_BLOCK_MAX_LEN) -> str:
    """Keep at most `max_len` characters of `text`, appending `...` if any were cut."""
    clipped = text[:max_len]
    return text if len(text) <= max_len else clipped + "..."
|
||||
|
||||
|
||||
def _format_content_block_xml(block: dict) -> str | None:
    """Format a content block as XML.

    Args:
        block: A LangChain content block.

    Returns:
        XML string representation of the block, or `None` if the block should be
        skipped (base64 data, empty payload, or unrecognized `type`).

    Note:
        Plain text document content, server tool call arguments, and server tool
        result outputs are truncated by `_truncate` (500 characters by default)
        before being escaped.
    """
    block_type = block.get("type", "")

    # Inline binary payloads are never serialized, whatever the block type.
    if _has_base64_data(block):
        return None

    # Text blocks
    if block_type == "text":
        text = block.get("text", "")
        return escape(text) if text else None

    # Reasoning blocks
    if block_type == "reasoning":
        reasoning = block.get("reasoning", "")
        return f"<reasoning>{escape(reasoning)}</reasoning>" if reasoning else None

    # Media blocks share one shape: a self-closing element named after the block
    # type that carries either a `url` or, failing that, a `file_id`.
    if block_type in ("image", "audio", "video"):
        for attr_name in ("url", "file_id"):
            attr_value = block.get(attr_name)
            if attr_value:
                # str() keeps a non-string reference from raising inside
                # quoteattr(), matching how the server tool ids are handled.
                return f"<{block_type} {attr_name}={quoteattr(str(attr_value))} />"
        return None

    # OpenAI-style image_url blocks
    if block_type == "image_url":
        image_url = block.get("image_url", {})
        if isinstance(image_url, dict):
            url = image_url.get("url", "")
            if url:
                # Coerce first so a non-string URL cannot break `.startswith`.
                url = str(url)
                if not url.startswith("data:"):
                    return f"<image url={quoteattr(url)} />"
        return None

    # Plain text document blocks
    if block_type == "text-plain":
        text = block.get("text", "")
        return escape(_truncate(text)) if text else None

    # Server tool call blocks (from AI messages)
    if block_type == "server_tool_call":
        tc_id = quoteattr(str(block.get("id") or ""))
        tc_name = quoteattr(str(block.get("name") or ""))
        tc_args_json = json.dumps(block.get("args", {}), ensure_ascii=False)
        # Truncate the raw JSON, then escape, so no entity is cut in half.
        tc_args = escape(_truncate(tc_args_json))
        return (
            f"<server_tool_call id={tc_id} name={tc_name}>{tc_args}</server_tool_call>"
        )

    # Server tool result blocks
    if block_type == "server_tool_result":
        tool_call_id = quoteattr(str(block.get("tool_call_id") or ""))
        status = quoteattr(str(block.get("status") or ""))
        output = block.get("output")
        if output:
            output_json = json.dumps(output, ensure_ascii=False)
            output_str = escape(_truncate(output_json))
        else:
            output_str = ""
        return (
            f"<server_tool_result tool_call_id={tool_call_id} status={status}>"
            f"{output_str}</server_tool_result>"
        )

    # Unknown block type - skip silently
    return None
|
||||
|
||||
|
||||
def _get_message_type_str(
    m: BaseMessage,
    human_prefix: str,
    ai_prefix: str,
    system_prefix: str,
    function_prefix: str,
    tool_prefix: str,
) -> str:
    """Resolve the `type` attribute value for a message's XML element.

    Args:
        m: The message to get the type string for.
        human_prefix: The prefix to use for `HumanMessage`.
        ai_prefix: The prefix to use for `AIMessage`.
        system_prefix: The prefix to use for `SystemMessage`.
        function_prefix: The prefix to use for `FunctionMessage`.
        tool_prefix: The prefix to use for `ToolMessage`.

    Returns:
        The lowercased prefix for the built-in message classes, or the role
        unchanged for a `ChatMessage`.

    Raises:
        ValueError: If an unsupported message type is encountered.
    """
    # Checked in order; the first matching class wins.
    prefix_by_class = (
        (HumanMessage, human_prefix),
        (AIMessage, ai_prefix),
        (SystemMessage, system_prefix),
        (FunctionMessage, function_prefix),
        (ToolMessage, tool_prefix),
    )
    for message_cls, prefix in prefix_by_class:
        if isinstance(m, message_cls):
            return prefix.lower()

    # ChatMessage roles are caller-defined and are passed through as given.
    if isinstance(m, ChatMessage):
        return m.role

    msg = f"Got unsupported message type: {m}"
    raise ValueError(msg)
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: Sequence[BaseMessage],
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "AI",
|
||||
*,
|
||||
system_prefix: str = "System",
|
||||
function_prefix: str = "Function",
|
||||
tool_prefix: str = "Tool",
|
||||
message_separator: str = "\n",
|
||||
format: Literal["prefix", "xml"] = "prefix", # noqa: A002
|
||||
) -> str:
|
||||
r"""Convert a sequence of messages to strings and concatenate them into one string.
|
||||
|
||||
@@ -110,7 +299,15 @@ def get_buffer_string(
|
||||
messages: Messages to be converted to strings.
|
||||
human_prefix: The prefix to prepend to contents of `HumanMessage`s.
|
||||
ai_prefix: The prefix to prepend to contents of `AIMessage`.
|
||||
system_prefix: The prefix to prepend to contents of `SystemMessage`s.
|
||||
function_prefix: The prefix to prepend to contents of `FunctionMessage`s.
|
||||
tool_prefix: The prefix to prepend to contents of `ToolMessage`s.
|
||||
message_separator: The separator to use between messages.
|
||||
format: The output format. `'prefix'` uses `Role: content` format (default).
|
||||
|
||||
`'xml'` uses XML-style `<message type='role'>` format with proper character
|
||||
escaping, which is useful when message content may contain role-like
|
||||
prefixes that could cause ambiguity.
|
||||
|
||||
Returns:
|
||||
A single string concatenation of all input messages.
|
||||
@@ -123,9 +320,33 @@ def get_buffer_string(
|
||||
and a function call under `additional_kwargs["function_call"]`, only the tool
|
||||
calls will be appended to the string representation.
|
||||
|
||||
When using `format='xml'`:
|
||||
|
||||
- All messages use uniform `<message type="role">content</message>` format.
|
||||
- The `type` attribute uses `human_prefix` (lowercased) for `HumanMessage`,
|
||||
`ai_prefix` (lowercased) for `AIMessage`, `system_prefix` (lowercased)
|
||||
for `SystemMessage`, `function_prefix` (lowercased) for `FunctionMessage`,
|
||||
`tool_prefix` (lowercased) for `ToolMessage`, and the original role
|
||||
(unchanged) for `ChatMessage`.
|
||||
- Message content is escaped using `xml.sax.saxutils.escape()`.
|
||||
- Attribute values are escaped using `xml.sax.saxutils.quoteattr()`.
|
||||
- AI messages with tool calls use nested structure with `<content>` and
|
||||
`<tool_call>` elements.
|
||||
- For multi-modal content (list of content blocks), supported block types
|
||||
are: `text`, `reasoning`, `image` (URL/file_id only), `image_url`
|
||||
(OpenAI-style, URL only), `audio` (URL/file_id only), `video` (URL/file_id
|
||||
only), `text-plain`, `server_tool_call`, and `server_tool_result`.
|
||||
- Content blocks with base64-encoded data are skipped (including blocks
|
||||
with `base64` field or `data:` URLs).
|
||||
- Unknown block types are skipped.
|
||||
- Plain text document content (`text-plain`), server tool call arguments,
|
||||
and server tool result outputs are truncated to 500 characters.
|
||||
|
||||
Example:
|
||||
Default prefix format:
|
||||
|
||||
```python
|
||||
from langchain_core import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="Hi, how are you?"),
|
||||
@@ -134,7 +355,54 @@ def get_buffer_string(
|
||||
get_buffer_string(messages)
|
||||
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
||||
```
|
||||
|
||||
XML format (useful when content contains role-like prefixes):
|
||||
|
||||
```python
|
||||
messages = [
|
||||
HumanMessage(content="Example: Human: some text"),
|
||||
AIMessage(content="I see the example."),
|
||||
]
|
||||
get_buffer_string(messages, format="xml")
|
||||
# -> '<message type="human">Example: Human: some text</message>\\n'
|
||||
# -> '<message type="ai">I see the example.</message>'
|
||||
```
|
||||
|
||||
XML format with special characters (automatically escaped):
|
||||
|
||||
```python
|
||||
messages = [
|
||||
HumanMessage(content="Is 5 < 10 & 10 > 5?"),
|
||||
]
|
||||
get_buffer_string(messages, format="xml")
|
||||
# -> '<message type="human">Is 5 &lt; 10 &amp; 10 &gt; 5?</message>'
|
||||
```
|
||||
|
||||
XML format with tool calls:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
AIMessage(
|
||||
content="I'll search for that.",
|
||||
tool_calls=[
|
||||
{"id": "call_123", "name": "search", "args": {"query": "weather"}}
|
||||
],
|
||||
),
|
||||
]
|
||||
get_buffer_string(messages, format="xml")
|
||||
# -> '<message type="ai">\\n'
|
||||
# -> ' <content>I\\'ll search for that.</content>\\n'
|
||||
# -> ' <tool_call id="call_123" name="search">'
|
||||
# -> '{"query": "weather"}</tool_call>\\n'
|
||||
# -> '</message>'
|
||||
```
|
||||
"""
|
||||
if format not in ("prefix", "xml"):
|
||||
msg = (
|
||||
f"Unrecognized format={format!r}. Supported formats are 'prefix' and 'xml'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
string_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
@@ -142,25 +410,92 @@ def get_buffer_string(
|
||||
elif isinstance(m, AIMessage):
|
||||
role = ai_prefix
|
||||
elif isinstance(m, SystemMessage):
|
||||
role = "System"
|
||||
role = system_prefix
|
||||
elif isinstance(m, FunctionMessage):
|
||||
role = "Function"
|
||||
role = function_prefix
|
||||
elif isinstance(m, ToolMessage):
|
||||
role = "Tool"
|
||||
role = tool_prefix
|
||||
elif isinstance(m, ChatMessage):
|
||||
role = m.role
|
||||
else:
|
||||
msg = f"Got unsupported message type: {m}"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
||||
message = f"{role}: {m.text}"
|
||||
if format == "xml":
|
||||
msg_type = _get_message_type_str(
|
||||
m, human_prefix, ai_prefix, system_prefix, function_prefix, tool_prefix
|
||||
)
|
||||
|
||||
if isinstance(m, AIMessage):
|
||||
if m.tool_calls:
|
||||
message += f"{m.tool_calls}"
|
||||
elif "function_call" in m.additional_kwargs:
|
||||
# Legacy behavior assumes only one function call per message
|
||||
message += f"{m.additional_kwargs['function_call']}"
|
||||
# Format content blocks
|
||||
if isinstance(m.content, str):
|
||||
content_parts = [escape(m.content)] if m.content else []
|
||||
else:
|
||||
# List of content blocks
|
||||
content_parts = []
|
||||
for block in m.content:
|
||||
if isinstance(block, str):
|
||||
if block:
|
||||
content_parts.append(escape(block))
|
||||
else:
|
||||
formatted = _format_content_block_xml(block)
|
||||
if formatted:
|
||||
content_parts.append(formatted)
|
||||
|
||||
# Check if this is an AIMessage with tool calls
|
||||
has_tool_calls = isinstance(m, AIMessage) and m.tool_calls
|
||||
has_function_call = (
|
||||
isinstance(m, AIMessage)
|
||||
and not m.tool_calls
|
||||
and "function_call" in m.additional_kwargs
|
||||
)
|
||||
|
||||
if has_tool_calls or has_function_call:
|
||||
# Use nested structure for AI messages with tool calls
|
||||
# Type narrowing: at this point m is AIMessage (verified above)
|
||||
ai_msg = cast("AIMessage", m)
|
||||
parts = [f"<message type={quoteattr(msg_type)}>"]
|
||||
if content_parts:
|
||||
parts.append(f" <content>{' '.join(content_parts)}</content>")
|
||||
|
||||
if has_tool_calls:
|
||||
for tc in ai_msg.tool_calls:
|
||||
tc_id = quoteattr(str(tc.get("id") or ""))
|
||||
tc_name = quoteattr(str(tc.get("name") or ""))
|
||||
tc_args = escape(
|
||||
json.dumps(tc.get("args", {}), ensure_ascii=False)
|
||||
)
|
||||
parts.append(
|
||||
f" <tool_call id={tc_id} name={tc_name}>"
|
||||
f"{tc_args}</tool_call>"
|
||||
)
|
||||
elif has_function_call:
|
||||
fc = ai_msg.additional_kwargs["function_call"]
|
||||
fc_name = quoteattr(str(fc.get("name") or ""))
|
||||
fc_args = escape(str(fc.get("arguments") or "{}"))
|
||||
parts.append(
|
||||
f" <function_call name={fc_name}>{fc_args}</function_call>"
|
||||
)
|
||||
|
||||
parts.append("</message>")
|
||||
message = "\n".join(parts)
|
||||
else:
|
||||
# Simple structure for messages without tool calls
|
||||
joined_content = " ".join(content_parts)
|
||||
message = (
|
||||
f"<message type={quoteattr(msg_type)}>{joined_content}</message>"
|
||||
)
|
||||
else: # format == "prefix"
|
||||
content = m.text
|
||||
message = f"{role}: {content}"
|
||||
tool_info = ""
|
||||
if isinstance(m, AIMessage):
|
||||
if m.tool_calls:
|
||||
tool_info = str(m.tool_calls)
|
||||
elif "function_call" in m.additional_kwargs:
|
||||
# Legacy behavior assumes only one function call per message
|
||||
tool_info = str(m.additional_kwargs["function_call"])
|
||||
if tool_info:
|
||||
message += tool_info # Preserve original behavior
|
||||
|
||||
string_messages.append(message)
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
@@ -1778,3 +1780,884 @@ def test_convert_to_openai_messages_reasoning_content() -> None:
|
||||
],
|
||||
}
|
||||
assert mixed_result == expected_mixed
|
||||
|
||||
|
||||
# Tests for get_buffer_string XML format
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_empty_messages_list() -> None:
    """An empty message list renders to an empty string in XML format."""
    no_messages: list[BaseMessage] = []
    assert get_buffer_string(no_messages, format="xml") == ""
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_basic() -> None:
    """Each built-in message type renders as one `<message>` element per line."""
    conversation = [
        SystemMessage(content="System message"),
        HumanMessage(content="Human message"),
        AIMessage(content="AI message"),
        FunctionMessage(content="Function result", name="test_fn"),
        ToolMessage(content="Tool result", tool_call_id="123"),
    ]
    expected_lines = [
        '<message type="system">System message</message>',
        '<message type="human">Human message</message>',
        '<message type="ai">AI message</message>',
        '<message type="function">Function result</message>',
        '<message type="tool">Tool result</message>',
    ]
    rendered = get_buffer_string(conversation, format="xml")
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_custom_prefixes() -> None:
    """Custom human/AI prefixes appear lowercased in the `type` attribute."""
    conversation = [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there"),
    ]
    rendered = get_buffer_string(
        conversation, human_prefix="User", ai_prefix="Assistant", format="xml"
    )
    expected_lines = [
        '<message type="user">Hello</message>',
        '<message type="assistant">Hi there</message>',
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_custom_separator() -> None:
    """A custom `message_separator` is placed between XML message elements."""
    conversation = [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there"),
    ]
    separator = "\n\n"
    rendered = get_buffer_string(
        conversation, format="xml", message_separator=separator
    )
    expected_elements = [
        '<message type="human">Hello</message>',
        '<message type="ai">Hi there</message>',
    ]
    assert rendered == separator.join(expected_elements)
|
||||
|
||||
|
||||
def test_get_buffer_string_prefix_custom_separator() -> None:
    """A custom `message_separator` is placed between prefix-format messages."""
    conversation = [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there"),
    ]
    rendered = get_buffer_string(
        conversation, format="prefix", message_separator=" | "
    )
    assert rendered == "Human: Hello | AI: Hi there"
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_escaping() -> None:
    """Test XML format properly escapes special characters in content."""
    messages = [
        HumanMessage(content="Is 5 < 10 & 10 > 5?"),
        AIMessage(content='Yes, and here\'s a "quote"'),
    ]
    result = get_buffer_string(messages, format="xml")
    # xml.sax.saxutils.escape replaces <, > and & with entities; quotes in
    # element content are left as-is.
    expected = (
        '<message type="human">Is 5 &lt; 10 &amp; 10 &gt; 5?</message>\n'
        '<message type="ai">Yes, and here\'s a "quote"</message>'
    )
    assert result == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_unicode_content() -> None:
    """Non-ASCII content passes through the XML format unchanged."""
    conversation = [
        HumanMessage(content="你好世界"),  # Chinese: Hello World
        AIMessage(content="こんにちは"),  # Japanese: Hello
    ]
    expected_lines = [
        '<message type="human">你好世界</message>',
        '<message type="ai">こんにちは</message>',
    ]
    rendered = get_buffer_string(conversation, format="xml")
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_chat_message_valid_role() -> None:
    """Test XML format renders a `ChatMessage` role as the `type` attribute value."""
    messages = [
        ChatMessage(content="Hello", role="Assistant"),
    ]
    result = get_buffer_string(messages, format="xml")
    # Role is used directly as the type attribute value (case is preserved)
    expected = '<message type="Assistant">Hello</message>'
    assert result == expected

    # Spaces in role
    messages = [
        ChatMessage(content="Hello", role="my custom role"),
    ]
    result = get_buffer_string(messages, format="xml")
    # Spaces are legal inside a quoted attribute value, so they pass through
    expected = '<message type="my custom role">Hello</message>'
    assert result == expected

    # Special characters in role
    messages = [
        ChatMessage(content="Hello", role='role"with<special>'),
    ]
    result = get_buffer_string(messages, format="xml")
    # quoteattr escapes < and > as entities and switches to single quotes when
    # the value contains double quotes
    expected = """<message type='role"with&lt;special&gt;'>Hello</message>"""
    assert result == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_empty_content() -> None:
    """Messages with empty content render as empty `<message>` elements."""
    conversation = [
        HumanMessage(content=""),
        AIMessage(content=""),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    assert rendered == '<message type="human"></message>\n<message type="ai"></message>'
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_tool_calls_with_content() -> None:
    """An `AIMessage` with `content` and `tool_calls` uses the nested XML layout."""
    weather_call = {
        "name": "get_weather",
        "args": {"city": "NYC"},
        "id": "call_1",
        "type": "tool_call",
    }
    conversation = [
        AIMessage(content="Let me check that", tool_calls=[weather_call]),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # One <content> child followed by one <tool_call> child
    expected_lines = [
        '<message type="ai">',
        " <content>Let me check that</content>",
        ' <tool_call id="call_1" name="get_weather">{"city": "NYC"}</tool_call>',
        "</message>",
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_tool_calls_empty_content() -> None:
    """An `AIMessage` with `tool_calls` but empty `content` omits `<content>`."""
    search_call = {
        "name": "search",
        "args": {"query": "test"},
        "id": "call_2",
        "type": "tool_call",
    }
    conversation = [
        AIMessage(content="", tool_calls=[search_call]),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # Only the tool_call child is emitted when there is no content
    expected_lines = [
        '<message type="ai">',
        ' <tool_call id="call_2" name="search">{"query": "test"}</tool_call>',
        "</message>",
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_tool_calls_escaping() -> None:
    """Test XML format escapes special characters in tool calls."""
    messages = [
        AIMessage(
            content="",
            tool_calls=[
                {
                    "name": "calculate",
                    "args": {"expression": "5 < 10 & 10 > 5"},
                    "id": "call_3",
                    "type": "tool_call",
                }
            ],
        ),
    ]
    result = get_buffer_string(messages, format="xml")
    # Special characters in tool_calls args should be escaped. Check for the
    # entity forms: the raw characters occur in any XML output (tag delimiters),
    # so asserting on them would prove nothing.
    assert "&lt;" in result
    assert "&gt;" in result
    assert "&amp;" in result
    # Verify overall structure
    assert result.startswith('<message type="ai">')
    assert result.endswith("</message>")
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_function_call_legacy() -> None:
    """A legacy `function_call` in `additional_kwargs` renders as `<function_call>`."""
    legacy_call = {"name": "test_fn", "arguments": '{"x": 1}'}
    conversation = [
        AIMessage(
            content="Calling function",
            additional_kwargs={"function_call": legacy_call},
        ),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # `arguments` is already a JSON string; its double quotes appear verbatim
    # in the element text.
    expected_lines = [
        '<message type="ai">',
        " <content>Calling function</content>",
        ' <function_call name="test_fn">{"x": 1}</function_call>',
        "</message>",
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_structured_content() -> None:
    """List-of-blocks content made of `text` blocks renders as plain text."""
    conversation = [
        HumanMessage(content=[{"type": "text", "text": "Hello, world!"}]),
        AIMessage(content=[{"type": "text", "text": "Hi there!"}]),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # Each text block contributes its text as the element body
    expected_lines = [
        '<message type="human">Hello, world!</message>',
        '<message type="ai">Hi there!</message>',
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_multiline_content() -> None:
    """Newlines inside message content are preserved in XML format."""
    body = "Line 1\nLine 2\nLine 3"
    rendered = get_buffer_string([HumanMessage(content=body)], format="xml")
    assert rendered == '<message type="human">Line 1\nLine 2\nLine 3</message>'
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_tool_calls_preferred_over_function_call() -> None:
    """When both are present, `tool_calls` is rendered and `function_call` is not."""
    modern_call = {
        "name": "modern_tool",
        "args": {"key": "value"},
        "id": "call_3",
        "type": "tool_call",
    }
    legacy_call = {"name": "legacy_function", "arguments": "{}"}
    conversation = [
        AIMessage(
            content="Calling tools",
            tool_calls=[modern_call],
            additional_kwargs={"function_call": legacy_call},
        ),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    assert "modern_tool" in rendered
    assert "legacy_function" not in rendered
    # The tool_call element is used, never the function_call element
    assert "<tool_call" in rendered
    assert "<function_call" not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_multiple_tool_calls() -> None:
    """Every entry in `tool_calls` gets its own `<tool_call>` element, in order."""
    weather_call = {
        "name": "get_weather",
        "args": {"city": "NYC"},
        "id": "call_1",
        "type": "tool_call",
    }
    time_call = {
        "name": "get_time",
        "args": {"timezone": "EST"},
        "id": "call_2",
        "type": "tool_call",
    }
    conversation = [
        AIMessage(content="I'll help with that", tool_calls=[weather_call, time_call]),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    expected_lines = [
        '<message type="ai">',
        " <content>I'll help with that</content>",
        ' <tool_call id="call_1" name="get_weather">{"city": "NYC"}</tool_call>',
        ' <tool_call id="call_2" name="get_time">{"timezone": "EST"}</tool_call>',
        "</message>",
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_tool_call_special_chars_in_attrs() -> None:
    """Double quotes in a tool call's `name`/`id` stay valid as attribute values."""
    quoted_call = {
        "name": 'search"with"quotes',
        "args": {"query": "test"},
        "id": 'call"id',
        "type": "tool_call",
    }
    conversation: list[BaseMessage] = [
        AIMessage(content="", tool_calls=[quoted_call]),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # quoteattr switches to single-quote delimiters when the value holds
    # double quotes
    for attribute in ("name='search\"with\"quotes'", "id='call\"id'"):
        assert attribute in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_tool_call_none_id() -> None:
    """A tool call whose `id` is `None` renders with an empty `id` attribute."""
    call_without_id = {
        "name": "search",
        "args": {},
        "id": None,
        "type": "tool_call",
    }
    conversation: list[BaseMessage] = [
        AIMessage(content="", tool_calls=[call_without_id]),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # None is rendered as the empty string
    assert 'id=""' in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_function_call_special_chars_in_name() -> None:
    """A double quote in a legacy `function_call` name stays valid as an attribute."""
    legacy_call = {"name": 'func"name', "arguments": "{}"}
    conversation: list[BaseMessage] = [
        AIMessage(content="", additional_kwargs={"function_call": legacy_call}),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    # quoteattr switches to single-quote delimiters when the value holds
    # double quotes
    assert "name='func\"name'" in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_invalid_format() -> None:
    """Unsupported `format` values raise `ValueError`."""
    conversation: list[BaseMessage] = [HumanMessage(content="Hello")]
    # A near-miss, an arbitrary word, and the empty string
    for bad_format in ("xm", "invalid", ""):
        with pytest.raises(ValueError, match="Unrecognized format"):
            get_buffer_string(conversation, format=bad_format)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_image_url_block() -> None:
    """An `image` block with a URL renders as an `<image url=... />` element."""
    conversation: list[BaseMessage] = [
        HumanMessage(
            content=[
                {"type": "text", "text": "What is in this image?"},
                {"type": "image", "url": "https://example.com/image.png"},
            ]
        ),
    ]
    rendered = get_buffer_string(conversation, format="xml")
    expected_fragments = (
        '<message type="human">',
        "What is in this image?",
        '<image url="https://example.com/image.png" />',
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_image_file_id_block() -> None:
    """An `image` block with a `file_id` renders as `<image file_id=... />`."""
    blocks = [
        {"type": "text", "text": "Describe this:"},
        {"type": "image", "file_id": "file-abc123"},
    ]
    conversation: list[BaseMessage] = [HumanMessage(content=blocks)]
    rendered = get_buffer_string(conversation, format="xml")
    assert '<image file_id="file-abc123" />' in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_image_base64_skipped() -> None:
    """An `image` block carrying base64 data is left out of the XML output."""
    blocks = [
        {"type": "text", "text": "What is this?"},
        {"type": "image", "base64": "iVBORw0KGgo...", "mime_type": "image/png"},
    ]
    conversation: list[BaseMessage] = [HumanMessage(content=blocks)]
    rendered = get_buffer_string(conversation, format="xml")
    assert "What is this?" in rendered
    # Neither the field name nor the payload may leak into the output
    for leaked in ("base64", "iVBORw0KGgo"):
        assert leaked not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_image_data_url_skipped() -> None:
    """Image blocks whose URL is a `data:` URL are left out of the XML output."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Check this:"},
            {"type": "image", "url": "data:image/png;base64,iVBORw0KGgo..."},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    text_kept = "Check this:" in rendered
    data_url_leaked = "data:image" in rendered
    assert text_kept
    assert not data_url_leaked
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_openai_image_url_block() -> None:
    """An OpenAI-style `image_url` block is rendered as an `<image>` tag."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Analyze this:"},
            {
                "type": "image_url",
                "image_url": {"url": "https://example.com/photo.jpg"},
            },
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    for fragment in (
        "Analyze this:",
        '<image url="https://example.com/photo.jpg" />',
    ):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_openai_image_url_data_skipped() -> None:
    """OpenAI-style `image_url` blocks with a `data:` URL are left out."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "See this:"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
            },
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    assert "See this:" in rendered
    # Neither the URL scheme nor the encoded payload may appear.
    for leaked in ("data:image", "/9j/4AAQ"):
        assert leaked not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_audio_url_block() -> None:
    """An audio block carrying a plain URL is rendered as an `<audio>` tag."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Transcribe this:"},
            {"type": "audio", "url": "https://example.com/audio.mp3"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    audio_tag = '<audio url="https://example.com/audio.mp3" />'
    assert "Transcribe this:" in rendered
    assert audio_tag in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_audio_base64_skipped() -> None:
    """Audio blocks holding inline base64 data are left out of the XML output."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Listen:"},
            {"type": "audio", "base64": "UklGRi...", "mime_type": "audio/wav"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    payload_prefix = "UklGRi"
    assert "Listen:" in rendered
    assert payload_prefix not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_video_url_block() -> None:
    """A video block carrying a plain URL is rendered as a `<video>` tag."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Describe this video:"},
            {"type": "video", "url": "https://example.com/video.mp4"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    video_tag = '<video url="https://example.com/video.mp4" />'
    assert "Describe this video:" in rendered
    assert video_tag in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_video_base64_skipped() -> None:
    """Video blocks holding inline base64 data are left out of the XML output."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Watch:"},
            {"type": "video", "base64": "AAAAFGZ0eXA...", "mime_type": "video/mp4"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    payload_prefix = "AAAAFGZ0eXA"
    assert "Watch:" in rendered
    assert payload_prefix not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_reasoning_block() -> None:
    """A reasoning block is wrapped in `<reasoning>` next to the answer text."""
    reply = AIMessage(
        content=[
            {"type": "reasoning", "reasoning": "Let me think about this..."},
            {"type": "text", "text": "The answer is 42."},
        ]
    )
    rendered = get_buffer_string([reply], format="xml")
    for fragment in (
        "<reasoning>Let me think about this...</reasoning>",
        "The answer is 42.",
    ):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_text_plain_block() -> None:
    """The text of a `text-plain` block shows up in the XML output."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Here is a document:"},
            {
                "type": "text-plain",
                "text": "Document content here.",
                "mime_type": "text/plain",
            },
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    for fragment in ("Here is a document:", "Document content here."):
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_server_tool_call_block() -> None:
    """A `server_tool_call` block becomes a tag with id, name and JSON args."""
    reply = AIMessage(
        content=[
            {"type": "text", "text": "Let me search for that."},
            {
                "type": "server_tool_call",
                "id": "call_123",
                "name": "web_search",
                "args": {"query": "weather today"},
            },
        ]
    )
    rendered = get_buffer_string([reply], format="xml")
    expected_fragments = (
        "Let me search for that.",
        '<server_tool_call id="call_123" name="web_search">',
        '{"query": "weather today"}',
        "</server_tool_call>",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_server_tool_result_block() -> None:
    """A `server_tool_result` block becomes a tag with call id, status, output."""
    prompt = HumanMessage(
        content=[
            {
                "type": "server_tool_result",
                "tool_call_id": "call_123",
                "status": "success",
                "output": {"temperature": 72, "conditions": "sunny"},
            },
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    expected_fragments = (
        '<server_tool_result tool_call_id="call_123" status="success">',
        '"temperature": 72',
        "</server_tool_result>",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_unknown_block_type_skipped() -> None:
    """Blocks with an unrecognized `type` leave no trace in the XML output."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Hello"},
            {"type": "unknown_type", "data": "some data"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    assert "Hello" in rendered
    # Neither the type name nor the payload of the unknown block may appear.
    for dropped in ("unknown_type", "some data"):
        assert dropped not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_mixed_content_blocks() -> None:
    """Several block types in one message are each handled by their own rule."""
    prompt = HumanMessage(
        content=[
            {"type": "text", "text": "Look at this image and document:"},
            {"type": "image", "url": "https://example.com/img.png"},
            {
                "type": "text-plain",
                "text": "Doc content",
                "mime_type": "text/plain",
            },
            # Inline base64 payload: expected to be dropped from the output.
            {"type": "image", "base64": "abc123", "mime_type": "image/png"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    kept_fragments = (
        "Look at this image and document:",
        '<image url="https://example.com/img.png" />',
        "Doc content",
    )
    for fragment in kept_fragments:
        assert fragment in rendered
    assert "abc123" not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_escaping_in_content_blocks() -> None:
    """Test that special XML characters are escaped in content blocks.

    Text and reasoning blocks containing `<`, `>` and `&` must appear in the
    output with those characters replaced by `&lt;`, `&gt;` and `&amp;`; the raw
    text must not be present, otherwise it could be mistaken for markup.
    """
    messages: list[BaseMessage] = [
        HumanMessage(
            content=[
                {"type": "text", "text": "Is 5 < 10 & 10 > 5?"},
                {"type": "reasoning", "reasoning": "Let's check: <value> & </value>"},
            ]
        ),
    ]
    result = get_buffer_string(messages, format="xml")
    # Escaped forms are expected ...
    assert "Is 5 &lt; 10 &amp; 10 &gt; 5?" in result
    assert "&lt;value&gt; &amp; &lt;/value&gt;" in result
    # ... and the unescaped originals must be gone.
    assert "Is 5 < 10 & 10 > 5?" not in result
    assert "<value> & </value>" not in result
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_url_with_special_chars() -> None:
    """Test that URLs with special characters are properly quoted.

    A `&` inside a URL is not valid in an XML attribute value as-is, so the
    rendered attribute must carry `&amp;` instead.
    """
    messages: list[BaseMessage] = [
        HumanMessage(
            content=[
                {"type": "image", "url": "https://example.com/img?a=1&b=2"},
            ]
        ),
    ]
    result = get_buffer_string(messages, format="xml")
    # quoteattr should escape the & in the URL
    assert "https://example.com/img?a=1&amp;b=2" in result
    # The raw ampersand must not survive into the attribute value.
    assert "a=1&b=2" not in result
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_text_plain_truncation() -> None:
    """A 600-char `text-plain` block is cut to 500 chars plus an ellipsis."""
    oversized = "x" * 600
    prompt = HumanMessage(
        content=[
            {"type": "text-plain", "text": oversized, "mime_type": "text/plain"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    # Exactly 500 characters are kept, followed by the "..." marker.
    truncated = "x" * 500 + "..."
    one_too_many = "x" * 501
    assert truncated in rendered
    assert one_too_many not in rendered
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_server_tool_call_args_truncation() -> None:
    """Oversized `server_tool_call` args are shortened and end in an ellipsis."""
    oversized = "y" * 600
    reply = AIMessage(
        content=[
            {
                "type": "server_tool_call",
                "id": "call_1",
                "name": "test_tool",
                "args": {"data": oversized},
            },
        ]
    )
    rendered = get_buffer_string([reply], format="xml")
    has_ellipsis = "..." in rendered
    # The complete 600-character value must have been cut.
    has_full_value = oversized in rendered
    assert has_ellipsis
    assert not has_full_value
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_server_tool_result_output_truncation() -> None:
    """Oversized `server_tool_result` output is shortened with an ellipsis."""
    oversized = "z" * 600
    prompt = HumanMessage(
        content=[
            {
                "type": "server_tool_result",
                "tool_call_id": "call_1",
                "status": "success",
                "output": {"result": oversized},
            },
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    has_ellipsis = "..." in rendered
    # The complete 600-character value must have been cut.
    has_full_value = oversized in rendered
    assert has_ellipsis
    assert not has_full_value
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_no_truncation_under_limit() -> None:
    """A 400-char `text-plain` block is emitted in full, with no ellipsis."""
    modest = "a" * 400
    prompt = HumanMessage(
        content=[
            {"type": "text-plain", "text": modest, "mime_type": "text/plain"},
        ]
    )
    rendered = get_buffer_string([prompt], format="xml")
    kept_whole = modest in rendered
    has_ellipsis = "..." in rendered
    assert kept_whole
    assert not has_ellipsis
|
||||
|
||||
|
||||
def test_get_buffer_string_custom_system_prefix() -> None:
    """A custom `system_prefix` replaces the label on system messages."""
    conversation: list[BaseMessage] = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
    ]
    rendered = get_buffer_string(conversation, system_prefix="Instructions")
    expected = "\n".join(
        ["Instructions: You are a helpful assistant.", "Human: Hello"]
    )
    assert rendered == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_custom_function_prefix() -> None:
    """A custom `function_prefix` replaces the label on function messages."""
    conversation: list[BaseMessage] = [
        HumanMessage(content="Call a function"),
        FunctionMessage(name="test_func", content="Function result"),
    ]
    rendered = get_buffer_string(conversation, function_prefix="Func")
    expected = "\n".join(["Human: Call a function", "Func: Function result"])
    assert rendered == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_custom_tool_prefix() -> None:
    """A custom `tool_prefix` replaces the label on tool messages."""
    conversation: list[BaseMessage] = [
        HumanMessage(content="Use a tool"),
        ToolMessage(tool_call_id="call_123", content="Tool result"),
    ]
    rendered = get_buffer_string(conversation, tool_prefix="ToolResult")
    expected = "\n".join(["Human: Use a tool", "ToolResult: Tool result"])
    assert rendered == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_all_custom_prefixes() -> None:
    """Every role label can be overridden in the same call."""
    conversation: list[BaseMessage] = [
        SystemMessage(content="System says hello"),
        HumanMessage(content="Human says hello"),
        AIMessage(content="AI says hello"),
        FunctionMessage(name="func", content="Function says hello"),
        ToolMessage(tool_call_id="call_1", content="Tool says hello"),
    ]
    custom_prefixes = {
        "human_prefix": "User",
        "ai_prefix": "Assistant",
        "system_prefix": "Sys",
        "function_prefix": "Fn",
        "tool_prefix": "T",
    }
    rendered = get_buffer_string(conversation, **custom_prefixes)
    # One line per message, in the order the messages were supplied.
    expected_lines = [
        "Sys: System says hello",
        "User: Human says hello",
        "Assistant: AI says hello",
        "Fn: Function says hello",
        "T: Tool says hello",
    ]
    assert rendered == "\n".join(expected_lines)
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_custom_system_prefix() -> None:
    """In XML format a custom `system_prefix` is lowercased into `type`."""
    system_msg = SystemMessage(content="You are a helpful assistant.")
    rendered = get_buffer_string(
        [system_msg], system_prefix="Instructions", format="xml"
    )
    expected = '<message type="instructions">You are a helpful assistant.</message>'
    assert rendered == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_custom_function_prefix() -> None:
    """In XML format a custom `function_prefix` is lowercased into `type`."""
    function_msg = FunctionMessage(name="test_func", content="Function result")
    rendered = get_buffer_string([function_msg], function_prefix="Fn", format="xml")
    expected = '<message type="fn">Function result</message>'
    assert rendered == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_custom_tool_prefix() -> None:
    """In XML format a custom `tool_prefix` is lowercased into `type`."""
    tool_msg = ToolMessage(tool_call_id="call_123", content="Tool result")
    rendered = get_buffer_string([tool_msg], tool_prefix="ToolOutput", format="xml")
    expected = '<message type="tooloutput">Tool result</message>'
    assert rendered == expected
|
||||
|
||||
|
||||
def test_get_buffer_string_xml_all_custom_prefixes() -> None:
    """In XML format every overridden role label shows up as a `type` value."""
    conversation: list[BaseMessage] = [
        SystemMessage(content="System message"),
        HumanMessage(content="Human message"),
        AIMessage(content="AI message"),
        FunctionMessage(name="func", content="Function message"),
        ToolMessage(tool_call_id="call_1", content="Tool message"),
    ]
    custom_prefixes = {
        "human_prefix": "User",
        "ai_prefix": "Assistant",
        "system_prefix": "Sys",
        "function_prefix": "Fn",
        "tool_prefix": "T",
    }
    rendered = get_buffer_string(conversation, **custom_prefixes, format="xml")
    # Only presence is checked here, not the relative order of the elements.
    expected_elements = (
        '<message type="sys">System message</message>',
        '<message type="user">Human message</message>',
        '<message type="assistant">AI message</message>',
        '<message type="fn">Function message</message>',
        '<message type="t">Tool message</message>',
    )
    for element in expected_elements:
        assert element in rendered
|
||||
|
||||
Reference in New Issue
Block a user