anthropic[patch]: multimodal (#18517)

- anthropic[minor]: claude 3
- x
- x

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
Erick Friis 2024-03-04 17:50:13 -08:00 committed by GitHub
parent 343438e872
commit 25c7d52140
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 11 deletions

View File

@ -1,5 +1,6 @@
import os import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import anthropic import anthropic
from langchain_core._api.deprecation import deprecated from langchain_core._api.deprecation import deprecated
@ -24,6 +25,33 @@ from langchain_core.utils import (
_message_type_lookups = {"human": "user", "ai": "assistant"} _message_type_lookups = {"human": "user", "ai": "assistant"}
def _format_image(image_url: str) -> Dict:
"""
Formats an image of format data:image/jpeg;base64,{b64_string}
to a dict for anthropic api
{
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRg...",
}
And throws an error if it's not a b64 image
"""
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
match = re.match(regex, image_url)
if match is None:
raise ValueError(
"Anthropic only supports base64-encoded images currently."
" Example: data:image/png;base64,'/9j/4AAQSk'..."
)
return {
"type": "base64",
"media_type": match.group("media_type"),
"data": match.group("data"),
}
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]: def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
"""Format messages for anthropic.""" """Format messages for anthropic."""
@ -36,20 +64,64 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
for m in messages for m in messages
] ]
""" """
system = None system: Optional[str] = None
formatted_messages = [] formatted_messages: List[Dict] = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
if not isinstance(message.content, str):
raise ValueError("Anthropic Messages API only supports text generation.")
if message.type == "system": if message.type == "system":
if i != 0: if i != 0:
raise ValueError("System message must be at beginning of message list.") raise ValueError("System message must be at beginning of message list.")
if not isinstance(message.content, str):
raise ValueError(
"System message must be a string, "
f"instead was: {type(message.content)}"
)
system = message.content system = message.content
continue
role = _message_type_lookups[message.type]
content: Union[str, List[Dict]]
if not isinstance(message.content, str):
# parse as dict
assert isinstance(
message.content, list
), "Anthropic message content must be str or list of dicts"
# populate content
content = []
for item in message.content:
if isinstance(item, str):
content.append(
{
"type": "text",
"text": item,
}
)
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
if item["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
content.append(
{
"type": "image",
"source": source,
}
)
else: else:
content.append(item)
else:
raise ValueError(
f"Content items must be str or dict, instead was: {type(item)}"
)
else:
content = message.content
formatted_messages.append( formatted_messages.append(
{ {
"role": _message_type_lookups[message.type], "role": role,
"content": message.content, "content": content,
} }
) )
return system, formatted_messages return system, formatted_messages

File diff suppressed because one or more lines are too long