update tests

This commit is contained in:
Chester Curme 2025-04-09 10:10:17 -04:00
parent 0354dec091
commit 35fbe24532
4 changed files with 82 additions and 67 deletions

View File

@@ -253,7 +253,7 @@ def _format_data_content_block(block: dict) -> dict:
def _format_messages( def _format_messages(
messages: List[BaseMessage], messages: Sequence[BaseMessage],
) -> Tuple[Union[str, List[Dict], None], List[Dict]]: ) -> Tuple[Union[str, List[Dict], None], List[Dict]]:
"""Format messages for anthropic.""" """Format messages for anthropic."""

View File

@@ -663,34 +663,6 @@ def test_pdf_document_input() -> None:
assert isinstance(result.content, str) assert isinstance(result.content, str)
assert len(result.content) > 0 assert len(result.content) > 0
# Test cache control with standard format
result = ChatAnthropic(model=IMAGE_MODEL_NAME).invoke(
[
HumanMessage(
[
{
"type": "text",
"text": "Summarize this document:",
},
{
"type": "file",
"source_type": "base64",
"mime_type": "application/pdf",
"source": data,
"metadata": {"cache_control": {"type": "ephemeral"}},
},
]
)
]
)
assert isinstance(result, AIMessage)
assert isinstance(result.content, str)
assert len(result.content) > 0
assert result.usage_metadata is not None
cache_creation = result.usage_metadata["input_token_details"]["cache_creation"]
cache_read = result.usage_metadata["input_token_details"]["cache_read"]
assert cache_creation > 0 or cache_read > 0
def test_citations() -> None: def test_citations() -> None:
llm = ChatAnthropic(model="claude-3-5-haiku-latest") llm = ChatAnthropic(model="claude-3-5-haiku-latest")
@@ -727,27 +699,6 @@ def test_citations() -> None:
assert any("citations" in block for block in full.content) assert any("citations" in block for block in full.content)
assert not any("citation" in block for block in full.content) assert not any("citation" in block for block in full.content)
# Test standard format
messages = [
{
"role": "user",
"content": [
{
"type": "file",
"source_type": "text",
"source": "The grass is green. The sky is blue.",
"mime_type": "text/plain",
"metadata": {"citations": {"enabled": True}},
},
{"type": "text", "text": "What color is the grass and sky?"},
],
}
]
response = llm.invoke(messages)
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
assert any("citations" in block for block in response.content)
def test_thinking() -> None: def test_thinking() -> None:
llm = ChatAnthropic( llm = ChatAnthropic(

View File

@@ -690,6 +690,85 @@ def test__format_messages_with_cache_control() -> None:
assert expected_system == actual_system assert expected_system == actual_system
assert expected_messages == actual_messages assert expected_messages == actual_messages
# Test standard multi-modal format
messages = [
HumanMessage(
[
{
"type": "text",
"text": "Summarize this document:",
},
{
"type": "file",
"source_type": "base64",
"mime_type": "application/pdf",
"source": "<base64 data>",
"metadata": {"cache_control": {"type": "ephemeral"}},
},
]
)
]
actual_system, actual_messages = _format_messages(messages)
assert actual_system is None
expected_messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Summarize this document:",
},
{
"type": "document",
"source": {
"type": "base64",
"media_type": "application/pdf",
"data": "<base64 data>",
},
"cache_control": {"type": "ephemeral"},
},
],
}
]
assert actual_messages == expected_messages
def test__format_messages_with_citations() -> None:
    """A standard ``file`` content block carrying citation metadata is
    converted into an Anthropic ``document`` block (with ``citations``
    hoisted to the top level), and no system message is produced."""
    source_text = "The grass is green. The sky is blue."
    question_block = {"type": "text", "text": "What color is the grass and sky?"}
    messages = [
        HumanMessage(
            content=[
                {
                    "type": "file",
                    "source_type": "text",
                    "source": source_text,
                    "mime_type": "text/plain",
                    "metadata": {"citations": {"enabled": True}},
                },
                question_block,
            ]
        )
    ]
    system, formatted = _format_messages(messages)
    # A lone human message should not yield an extracted system prompt.
    assert system is None
    document_block = {
        "type": "document",
        "source": {
            "type": "text",
            "media_type": "text/plain",
            "data": source_text,
        },
        "citations": {"enabled": True},
    }
    assert formatted == [{"role": "user", "content": [document_block, question_block]}]
def test__format_messages_with_multiple_system() -> None: def test__format_messages_with_multiple_system() -> None:
messages = [ messages = [

View File

@@ -68,6 +68,7 @@ from langchain_core.messages import (
ToolCall, ToolCall,
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk,
convert_image_content_block_to_image_url,
is_data_content_block, is_data_content_block,
) )
from langchain_core.messages.ai import ( from langchain_core.messages.ai import (
@@ -195,23 +196,7 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
def _format_data_content_block(block: dict) -> dict: def _format_data_content_block(block: dict) -> dict:
"""Format standard data content block to format expected by OpenAI.""" """Format standard data content block to format expected by OpenAI."""
if block["type"] == "image": if block["type"] == "image":
if block["source_type"] == "url": formatted_block = convert_image_content_block_to_image_url(block) # type: ignore[arg-type]
formatted_block = {
"type": "image_url",
"image_url": {"url": block["source"]},
}
elif block["source_type"] == "base64":
formatted_block = {
"type": "image_url",
"image_url": {
"url": f"data:{block['mime_type']};base64,{block['source']}"
},
}
else:
raise ValueError(
"OpenAI only supports 'url' and 'base64' source_type for image "
"content blocks."
)
elif block["type"] == "file": elif block["type"] == "file":
if block["source_type"] == "base64": if block["source_type"] == "base64":