mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 22:42:05 +00:00
Support for claude v3 models. (#18630)
Fixes #18513. ## Description This PR attempts to fix the support for Anthropic Claude v3 models in BedrockChat LLM. The changes here have updated the payload to use the `messages` format instead of the formatted text prompt for all models; the `messages` API is backwards compatible with all models in Anthropic, so this should not break the experience for any models. ## Notes The PR in the current form does not support the v3 models for the non-chat Bedrock LLM. This means that, with these changes, users won't be able to use the v3 models with the Bedrock LLM. I can open a separate PR to tackle this use-case; the intent here was to get this out quickly, so users can start using and testing the chat LLM. The Bedrock LLM classes have also grown complex, with a lot of conditions to support various providers and models, and are ripe for a refactor to make future changes more palatable. This refactor is likely to take longer, and requires more thorough testing from the community. Credit to PRs [18579](https://github.com/langchain-ai/langchain/pull/18579) and [18548](https://github.com/langchain-ai/langchain/pull/18548) for some of the code here. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
import re
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -19,6 +20,110 @@ from langchain_community.utilities.anthropic import (
|
||||
)
|
||||
|
||||
|
||||
def _format_image(image_url: str) -> Dict:
|
||||
"""
|
||||
Formats an image of format data:image/jpeg;base64,{b64_string}
|
||||
to a dict for anthropic api
|
||||
|
||||
{
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "/9j/4AAQSkZJRg...",
|
||||
}
|
||||
|
||||
And throws an error if it's not a b64 image
|
||||
"""
|
||||
regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
|
||||
match = re.match(regex, image_url)
|
||||
if match is None:
|
||||
raise ValueError(
|
||||
"Anthropic only supports base64-encoded images currently."
|
||||
" Example: data:image/png;base64,'/9j/4AAQSk'..."
|
||||
)
|
||||
return {
|
||||
"type": "base64",
|
||||
"media_type": match.group("media_type"),
|
||||
"data": match.group("data"),
|
||||
}
|
||||
|
||||
|
||||
def _format_anthropic_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict]]:
    """Format LangChain messages into Anthropic's messages-API payload.

    Args:
        messages: Chat messages. An optional system message must be the
            first element and must have string content.

    Returns:
        A ``(system, formatted_messages)`` tuple: ``system`` is the system
        prompt string (or ``None``), and ``formatted_messages`` is a list of
        ``{"role": ..., "content": ...}`` dicts for the remaining messages.

    Raises:
        ValueError: If a system message is not first, if its content is not
            a string, or if a message's content is of an unsupported type.
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            if not isinstance(message.content, str):
                raise ValueError(
                    "System message must be a string, "
                    f"instead was: {type(message.content)}"
                )
            system = message.content
            continue

        role = _message_type_lookups[message.type]
        content: Union[str, List[Dict]]

        if not isinstance(message.content, str):
            # Raise (rather than assert) so validation survives `python -O`.
            if not isinstance(message.content, list):
                raise ValueError(
                    "Anthropic message content must be str or list of dicts"
                )

            # Normalize each item to an Anthropic content block.
            content = []
            for item in message.content:
                if isinstance(item, str):
                    content.append(
                        {
                            "type": "text",
                            "text": item,
                        }
                    )
                elif isinstance(item, dict):
                    if "type" not in item:
                        raise ValueError("Dict content item must have a type key")
                    if item["type"] == "image_url":
                        # Convert OpenAI-style image_url blocks into
                        # Anthropic's base64 image source format.
                        source = _format_image(item["image_url"]["url"])
                        content.append(
                            {
                                "type": "image",
                                "source": source,
                            }
                        )
                    else:
                        # Already in Anthropic block format; pass through.
                        content.append(item)
                else:
                    raise ValueError(
                        f"Content items must be str or dict, instead was: {type(item)}"
                    )
        else:
            content = message.content

        formatted_messages.append(
            {
                "role": role,
                "content": content,
            }
        )
    return system, formatted_messages
|
||||
|
||||
|
||||
class ChatPromptAdapter:
|
||||
"""Adapter class to prepare the inputs from Langchain to prompt format
|
||||
that Chat model expects.
|
||||
@@ -44,6 +149,20 @@ class ChatPromptAdapter:
|
||||
)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def format_messages(
|
||||
cls, provider: str, messages: List[BaseMessage]
|
||||
) -> Tuple[Optional[str], List[Dict]]:
|
||||
if provider == "anthropic":
|
||||
return _format_anthropic_messages(messages)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Provider {provider} not supported for format_messages"
|
||||
)
|
||||
|
||||
|
||||
# Map LangChain message types to the role names Anthropic's messages API expects.
_message_type_lookups = {"human": "user", "ai": "assistant"}
|
||||
|
||||
|
||||
class BedrockChat(BaseChatModel, BedrockBase):
|
||||
"""A chat model that uses the Bedrock API."""
|
||||
@@ -85,12 +204,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
provider = self._get_provider()
|
||||
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||
provider=provider, messages=messages
|
||||
)
|
||||
system = None
|
||||
formatted_messages = None
|
||||
if provider == "anthropic":
|
||||
prompt = None
|
||||
system, formatted_messages = ChatPromptAdapter.format_messages(
|
||||
provider, messages
|
||||
)
|
||||
else:
|
||||
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||
provider=provider, messages=messages
|
||||
)
|
||||
|
||||
for chunk in self._prepare_input_and_invoke_stream(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
messages=formatted_messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
):
|
||||
delta = chunk.text
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
|
||||
@@ -109,20 +241,34 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
completion += chunk.text
|
||||
else:
|
||||
provider = self._get_provider()
|
||||
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||
provider=provider, messages=messages
|
||||
)
|
||||
|
||||
system = None
|
||||
formatted_messages = None
|
||||
params: Dict[str, Any] = {**kwargs}
|
||||
if provider == "anthropic":
|
||||
prompt = None
|
||||
system, formatted_messages = ChatPromptAdapter.format_messages(
|
||||
provider, messages
|
||||
)
|
||||
else:
|
||||
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||
provider=provider, messages=messages
|
||||
)
|
||||
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
|
||||
completion = self._prepare_input_and_invoke(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **params
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
system=system,
|
||||
messages=formatted_messages,
|
||||
**params,
|
||||
)
|
||||
|
||||
message = AIMessage(content=completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=completion))]
|
||||
)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
if self._model_is_anthropic:
|
||||
|
Reference in New Issue
Block a user