diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 0a81b81b3be..5e4a16c384f 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1,5 +1,6 @@ import os -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple +import re +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union import anthropic from langchain_core._api.deprecation import deprecated @@ -24,6 +25,33 @@ from langchain_core.utils import ( _message_type_lookups = {"human": "user", "ai": "assistant"} +def _format_image(image_url: str) -> Dict: + """ + Formats an image of format data:image/jpeg;base64,{b64_string} + to a dict for anthropic api + + { + "type": "base64", + "media_type": "image/jpeg", + "data": "/9j/4AAQSkZJRg...", + } + + And throws an error if it's not a b64 image + """ + regex = r"^data:(?Pimage/.+);base64,(?P.+)$" + match = re.match(regex, image_url) + if match is None: + raise ValueError( + "Anthropic only supports base64-encoded images currently." + " Example: data:image/png;base64,'/9j/4AAQSk'..." + ) + return { + "type": "base64", + "media_type": match.group("media_type"), + "data": match.group("data"), + } + + def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]: """Format messages for anthropic.""" @@ -36,22 +64,66 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D for m in messages ] """ - system = None - formatted_messages = [] + system: Optional[str] = None + formatted_messages: List[Dict] = [] for i, message in enumerate(messages): - if not isinstance(message.content, str): - raise ValueError("Anthropic Messages API only supports text generation.") if message.type == "system": if i != 0: raise ValueError("System message must be at beginning of message list.") + if not isinstance(message.content, str): + raise ValueError( + "System message must be a string, " + f"instead was: {type(message.content)}" + ) system = message.content + continue + + role = _message_type_lookups[message.type] + content: Union[str, List[Dict]] + + if not isinstance(message.content, str): + # parse as dict + assert isinstance( + message.content, list + ), "Anthropic message content must be str or list of dicts" + + # populate content + content = [] + for item in message.content: + if isinstance(item, str): + content.append( + { + "type": "text", + "text": item, + } + ) + elif isinstance(item, dict): + if "type" not in item: + raise ValueError("Dict content item must have a type key") + if item["type"] == "image_url": + # convert format + source = _format_image(item["image_url"]["url"]) + content.append( + { + "type": "image", + "source": source, + } + ) + else: + content.append(item) + else: + raise ValueError( + f"Content items must be str or dict, instead was: {type(item)}" + ) else: - formatted_messages.append( - { - "role": _message_type_lookups[message.type], - "content": message.content, - } - ) + content = message.content + + formatted_messages.append( + { + "role": role, + "content": content, + } + ) return system, formatted_messages diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index da5f376aa3b..f1050f4fea9 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -162,3 +162,25 @@ async def test_anthropic_async_streaming_callback() -> None: assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) assert callback_handler.llm_streams > 1 + + +def test_anthropic_multimodal() -> None: + """Test that multimodal inputs are handled correctly.""" + chat = ChatAnthropic(model=MODEL_NAME) + messages = [ + HumanMessage( + content=[ + { + "type": "image_url", + "image_url": { + # langchain logo + "url": "", # noqa: E501 + }, + }, + {"type": "text", "text": "What is this a logo for?"}, + ] + ) + ] + response = chat.invoke(messages) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str)