mirror of https://github.com/hwchase17/langchain.git
synced 2025-08-16 16:11:02 +00:00
Support for claude v3 models. (#18630)
Fixes #18513.

## Description

This PR fixes support for Anthropic Claude v3 models in the BedrockChat LLM. The changes update the request payload to use the `messages` format instead of the formatted text prompt for all Anthropic models; the `messages` API is backwards compatible with all Anthropic models, so this should not break the experience for any of them.

## Notes

In its current form the PR does not support the v3 models in the non-chat Bedrock LLM, so with these changes users still won't be able to use the v3 models with the `Bedrock` LLM class. A separate PR can tackle that use case; the intent here was to get this out quickly so users can start using and testing the chat LLM. The Bedrock LLM classes have also grown complex, with many conditions to support various providers and models, and are ripe for a refactor to make future changes more palatable. That refactor is likely to take longer and requires more thorough testing from the community.

Credit to PRs [18579](https://github.com/langchain-ai/langchain/pull/18579) and [18548](https://github.com/langchain-ai/langchain/pull/18548) for some of the code here.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
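For orientation, here is a minimal sketch of the payload change this PR describes for the Anthropic provider on Bedrock. The field names follow the diff below; the concrete prompt text, system text, and token limits are illustrative only.

```python
# Old behaviour (Text Completions): every Anthropic call sent a formatted prompt.
legacy_body = {
    "prompt": "\n\nHuman: What is the capital of France?\n\nAssistant:",
    "max_tokens_to_sample": 256,
}

# New behaviour (Messages API): BedrockChat now sends role-based messages, which
# Claude v3 requires and which older Claude models also accept.
messages_body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 1024,
    "system": "You are a helpful assistant.",
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
    ],
}
```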
This commit is contained in:
parent 1b4dcf22f3
commit 2b234a4d96
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Iterator, List, Optional
+import re
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
@@ -19,6 +20,110 @@ from langchain_community.utilities.anthropic import (
 )
 
 
+def _format_image(image_url: str) -> Dict:
+    """
+    Formats an image of format data:image/jpeg;base64,{b64_string}
+    to a dict for anthropic api
+
+    {
+      "type": "base64",
+      "media_type": "image/jpeg",
+      "data": "/9j/4AAQSkZJRg...",
+    }
+
+    And throws an error if it's not a b64 image
+    """
+    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
+    match = re.match(regex, image_url)
+    if match is None:
+        raise ValueError(
+            "Anthropic only supports base64-encoded images currently."
+            " Example: data:image/png;base64,'/9j/4AAQSk'..."
+        )
+    return {
+        "type": "base64",
+        "media_type": match.group("media_type"),
+        "data": match.group("data"),
+    }
+
+
+def _format_anthropic_messages(
+    messages: List[BaseMessage],
+) -> Tuple[Optional[str], List[Dict]]:
+    """Format messages for anthropic."""
+
+    """
+    [
+        {
+            "role": _message_type_lookups[m.type],
+            "content": [_AnthropicMessageContent(text=m.content).dict()],
+        }
+        for m in messages
+    ]
+    """
+    system: Optional[str] = None
+    formatted_messages: List[Dict] = []
+    for i, message in enumerate(messages):
+        if message.type == "system":
+            if i != 0:
+                raise ValueError("System message must be at beginning of message list.")
+            if not isinstance(message.content, str):
+                raise ValueError(
+                    "System message must be a string, "
+                    f"instead was: {type(message.content)}"
+                )
+            system = message.content
+            continue
+
+        role = _message_type_lookups[message.type]
+        content: Union[str, List[Dict]]
+
+        if not isinstance(message.content, str):
+            # parse as dict
+            assert isinstance(
+                message.content, list
+            ), "Anthropic message content must be str or list of dicts"
+
+            # populate content
+            content = []
+            for item in message.content:
+                if isinstance(item, str):
+                    content.append(
+                        {
+                            "type": "text",
+                            "text": item,
+                        }
+                    )
+                elif isinstance(item, dict):
+                    if "type" not in item:
+                        raise ValueError("Dict content item must have a type key")
+                    if item["type"] == "image_url":
+                        # convert format
+                        source = _format_image(item["image_url"]["url"])
+                        content.append(
+                            {
+                                "type": "image",
+                                "source": source,
+                            }
+                        )
+                    else:
+                        content.append(item)
+                else:
+                    raise ValueError(
+                        f"Content items must be str or dict, instead was: {type(item)}"
+                    )
+        else:
+            content = message.content
+
+        formatted_messages.append(
+            {
+                "role": role,
+                "content": content,
+            }
+        )
+    return system, formatted_messages
+
+
 class ChatPromptAdapter:
     """Adapter class to prepare the inputs from Langchain to prompt format
     that Chat model expects.
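A hedged usage sketch of the two helpers added above. The import path assumes they live alongside BedrockChat in `langchain_community.chat_models.bedrock`, and the base64 data is truncated dummy content.

```python
from langchain_core.messages import HumanMessage, SystemMessage

# Assumed module path for the private helpers shown in the hunk above.
from langchain_community.chat_models.bedrock import _format_anthropic_messages

system, formatted = _format_anthropic_messages(
    [
        SystemMessage(content="Describe images concisely."),
        HumanMessage(
            content=[
                {"type": "text", "text": "What is in this picture?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."},
                },
            ]
        ),
    ]
)
# system == "Describe images concisely."
# formatted[0]["role"] == "user"
# formatted[0]["content"][1] == {
#     "type": "image",
#     "source": {"type": "base64", "media_type": "image/jpeg", "data": "/9j/4AAQSkZJRg..."},
# }
```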
@@ -44,6 +149,20 @@ class ChatPromptAdapter:
         )
         return prompt
 
+    @classmethod
+    def format_messages(
+        cls, provider: str, messages: List[BaseMessage]
+    ) -> Tuple[Optional[str], List[Dict]]:
+        if provider == "anthropic":
+            return _format_anthropic_messages(messages)
+
+        raise NotImplementedError(
+            f"Provider {provider} not supported for format_messages"
+        )
+
+
+_message_type_lookups = {"human": "user", "ai": "assistant"}
+
 
 class BedrockChat(BaseChatModel, BedrockBase):
     """A chat model that uses the Bedrock API."""
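And the new classmethod in context, again as an assumption-laden sketch (same assumed module path, plain-text messages only); providers other than `anthropic` keep going through the existing `convert_messages_to_prompt` path.

```python
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.chat_models.bedrock import ChatPromptAdapter  # assumed path

system, formatted = ChatPromptAdapter.format_messages(
    provider="anthropic",
    messages=[
        SystemMessage(content="Answer in one word."),
        HumanMessage(content="Capital of France?"),
    ],
)
# system == "Answer in one word."
# formatted == [{"role": "user", "content": "Capital of France?"}]

# Any other provider is rejected:
# ChatPromptAdapter.format_messages(provider="cohere", messages=[...])
#   -> NotImplementedError: Provider cohere not supported for format_messages
```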
@@ -85,12 +204,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         provider = self._get_provider()
-        prompt = ChatPromptAdapter.convert_messages_to_prompt(
-            provider=provider, messages=messages
-        )
+        system = None
+        formatted_messages = None
+        if provider == "anthropic":
+            prompt = None
+            system, formatted_messages = ChatPromptAdapter.format_messages(
+                provider, messages
+            )
+        else:
+            prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                provider=provider, messages=messages
+            )
+
         for chunk in self._prepare_input_and_invoke_stream(
-            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            prompt=prompt,
+            system=system,
+            messages=formatted_messages,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
         ):
             delta = chunk.text
             yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
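End-to-end, the streaming path above corresponds to something like the following sketch. The model id, region, and temperature are example values, and AWS credentials plus Bedrock model access are assumed to be configured.

```python
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

chat = BedrockChat(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.0},
)

for chunk in chat.stream([HumanMessage(content="Name three French cities.")]):
    print(chunk.content, end="", flush=True)
```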
@@ -109,20 +241,34 @@ class BedrockChat(BaseChatModel, BedrockBase):
             completion += chunk.text
         else:
             provider = self._get_provider()
-            prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
-            )
+            system = None
+            formatted_messages = None
 
             params: Dict[str, Any] = {**kwargs}
+            if provider == "anthropic":
+                prompt = None
+                system, formatted_messages = ChatPromptAdapter.format_messages(
+                    provider, messages
+                )
+            else:
+                prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                    provider=provider, messages=messages
+                )
+
             if stop:
                 params["stop_sequences"] = stop
 
             completion = self._prepare_input_and_invoke(
-                prompt=prompt, stop=stop, run_manager=run_manager, **params
+                prompt=prompt,
+                stop=stop,
+                run_manager=run_manager,
+                system=system,
+                messages=formatted_messages,
+                **params,
             )
 
-        message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=completion))]
+        )
 
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
@@ -77,6 +77,20 @@ def _human_assistant_format(input_text: str) -> str:
     return input_text
 
 
+def _stream_response_to_generation_chunk(
+    stream_response: Dict[str, Any],
+) -> GenerationChunk:
+    """Convert a stream response to a generation chunk."""
+    if not stream_response["delta"]:
+        return GenerationChunk(text="")
+    return GenerationChunk(
+        text=stream_response["delta"]["text"],
+        generation_info=dict(
+            finish_reason=stream_response.get("stop_reason", None),
+        ),
+    )
+
+
 class LLMInputOutputAdapter:
     """Adapter class to prepare the inputs from Langchain to a format
     that LLM model expects.
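A sketch of the event shape the new helper above consumes. The field names mirror Anthropic's `content_block_delta` stream events, and the import path is an assumption.

```python
from langchain_community.llms.bedrock import _stream_response_to_generation_chunk  # assumed path

event = {
    "type": "content_block_delta",
    "index": 0,
    "delta": {"type": "text_delta", "text": "Hello"},
}

chunk = _stream_response_to_generation_chunk(event)
# chunk.text == "Hello"
# chunk.generation_info == {"finish_reason": None}
```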
@@ -93,11 +107,26 @@ class LLMInputOutputAdapter:
 
     @classmethod
     def prepare_input(
-        cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
+        cls,
+        provider: str,
+        model_kwargs: Dict[str, Any],
+        prompt: Optional[str] = None,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
         if provider == "anthropic":
-            input_body["prompt"] = _human_assistant_format(prompt)
+            if messages:
+                input_body["anthropic_version"] = "bedrock-2023-05-31"
+                input_body["messages"] = messages
+                if system:
+                    input_body["system"] = system
+                if "max_tokens" not in input_body:
+                    input_body["max_tokens"] = 1024
+            if prompt:
+                input_body["prompt"] = _human_assistant_format(prompt)
+                if "max_tokens_to_sample" not in input_body:
+                    input_body["max_tokens_to_sample"] = 1024
         elif provider in ("ai21", "cohere", "meta"):
             input_body["prompt"] = prompt
         elif provider == "amazon":
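A sketch of the two `prepare_input` branches above (assumed import path; temperature and message text are made up). The chat path passes `system`/`messages`, while the legacy text path still passes a `prompt` string.

```python
from langchain_community.llms.bedrock import LLMInputOutputAdapter  # assumed path

body_v3 = LLMInputOutputAdapter.prepare_input(
    provider="anthropic",
    model_kwargs={"temperature": 0.0},
    system="You are concise.",
    messages=[{"role": "user", "content": "Ping?"}],
)
# -> {"temperature": 0.0, "anthropic_version": "bedrock-2023-05-31",
#     "messages": [...], "system": "You are concise.", "max_tokens": 1024}

body_legacy = LLMInputOutputAdapter.prepare_input(
    provider="anthropic",
    model_kwargs={},
    prompt="Ping?",
)
# -> {"prompt": "\n\nHuman: Ping?\n\nAssistant:", "max_tokens_to_sample": 1024}
#    (exact prompt wrapping comes from _human_assistant_format)
```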
@@ -107,16 +136,17 @@ class LLMInputOutputAdapter:
         else:
             input_body["inputText"] = prompt
 
-        if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 256
-
         return input_body
 
     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            text = response_body.get("completion")
+            if "completion" in response_body:
+                text = response_body.get("completion")
+            elif "content" in response_body:
+                content = response_body.get("content")
+                text = content[0].get("text")
         else:
             response_body = json.loads(response.get("body").read())
 
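For reference, the two Anthropic response-body shapes the branch above distinguishes. Only the fields the code reads are shown, and the `stop_reason` values are illustrative.

```python
legacy_body = {"completion": " Paris.", "stop_reason": "stop_sequence"}
claude_v3_body = {
    "content": [{"type": "text", "text": "Paris."}],
    "stop_reason": "end_turn",
}

# The updated code picks the text like this:
for body in (legacy_body, claude_v3_body):
    if "completion" in body:
        text = body["completion"]
    elif "content" in body:
        text = body["content"][0]["text"]
    print(text)
```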
@@ -136,14 +166,21 @@ class LLMInputOutputAdapter:
 
     @classmethod
     def prepare_output_stream(
-        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+        cls,
+        provider: str,
+        response: Any,
+        stop: Optional[List[str]] = None,
+        messages_api: bool = False,
     ) -> Iterator[GenerationChunk]:
         stream = response.get("body")
 
         if not stream:
             return
 
-        output_key = cls.provider_to_output_key_map.get(provider, None)
+        if messages_api:
+            output_key = "message"
+        else:
+            output_key = cls.provider_to_output_key_map.get(provider, "")
 
         if not output_key:
             raise ValueError(
@@ -161,15 +198,29 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return
+            elif messages_api and (chunk_obj.get("type") == "content_block_stop"):
+                return
 
-            # chunk obj format varies with provider
-            yield GenerationChunk(
-                text=chunk_obj[output_key],
-                generation_info={
-                    GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
-                    if GUARDRAILS_BODY_KEY in chunk_obj
-                    else None,
-                },
-            )
+            if messages_api and chunk_obj.get("type") in (
+                "message_start",
+                "content_block_start",
+                "content_block_delta",
+            ):
+                if chunk_obj.get("type") == "content_block_delta":
+                    chk = _stream_response_to_generation_chunk(chunk_obj)
+                    yield chk
+                else:
+                    continue
+            else:
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[output_key],
+                    generation_info={
+                        GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
+                        if GUARDRAILS_BODY_KEY in chunk_obj
+                        else None,
+                    },
+                )
 
     @classmethod
     async def aprepare_output_stream(
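The `messages_api` branch above keys off Anthropic's stream event types. A trimmed, illustrative event sequence and how each event is handled by the code in the hunk:

```python
events = [
    {"type": "message_start", "message": {"role": "assistant"}},  # skipped (continue)
    {"type": "content_block_start", "index": 0},                  # skipped (continue)
    {"type": "content_block_delta", "delta": {"text": "Par"}},    # yielded as a chunk "Par"
    {"type": "content_block_delta", "delta": {"text": "is"}},     # yielded as a chunk "is"
    {"type": "content_block_stop", "index": 0},                   # ends the stream (return)
]
```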
@@ -412,7 +463,9 @@ class BedrockBase(BaseModel, ABC):
 
     def _prepare_input_and_invoke(
         self,
-        prompt: str,
+        prompt: Optional[str] = None,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
@@ -423,7 +476,13 @@ class BedrockBase(BaseModel, ABC):
         params = {**_model_kwargs, **kwargs}
         if self._guardrails_enabled:
             params.update(self._get_guardrails_canonical())
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        input_body = LLMInputOutputAdapter.prepare_input(
+            provider=provider,
+            model_kwargs=params,
+            prompt=prompt,
+            system=system,
+            messages=messages,
+        )
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
@@ -498,7 +557,9 @@ class BedrockBase(BaseModel, ABC):
 
     def _prepare_input_and_invoke_stream(
         self,
-        prompt: str,
+        prompt: Optional[str] = None,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
@@ -524,7 +585,13 @@ class BedrockBase(BaseModel, ABC):
         if self._guardrails_enabled:
             params.update(self._get_guardrails_canonical())
 
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        input_body = LLMInputOutputAdapter.prepare_input(
+            provider=provider,
+            prompt=prompt,
+            system=system,
+            messages=messages,
+            model_kwargs=params,
+        )
         body = json.dumps(input_body)
 
         request_options = {
@@ -546,7 +613,7 @@ class BedrockBase(BaseModel, ABC):
             raise ValueError(f"Error raised by bedrock service: {e}")
 
         for chunk in LLMInputOutputAdapter.prepare_output_stream(
-            provider, response, stop
+            provider, response, stop, True if messages else False
         ):
             yield chunk
             # verify and raise callback error if any middleware intervened
@@ -576,7 +643,9 @@ class BedrockBase(BaseModel, ABC):
             _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        input_body = LLMInputOutputAdapter.prepare_input(
+            provider=provider, prompt=prompt, model_kwargs=params
+        )
         body = json.dumps(input_body)
 
         response = await asyncio.get_running_loop().run_in_executor(
@@ -629,6 +698,17 @@ class Bedrock(LLM, BedrockBase):
 
     """
 
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        model_id = values["model_id"]
+        if model_id.startswith("anthropic.claude-3"):
+            raise ValueError(
+                "Claude v3 models are not supported by this LLM."
+                "Please use `from langchain_community.chat_models import BedrockChat` "
+                "instead."
+            )
+        return super().validate_environment(values)
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
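A sketch of the new guard in the plain-text `Bedrock` LLM, with an example Claude v3 model id. Per the validator above, the error is raised at construction time, before any AWS client is created; the second part assumes AWS credentials and a default region are configured.

```python
from langchain_community.llms import Bedrock

try:
    Bedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
except ValueError as err:
    print(err)  # mentions BedrockChat as the supported route

# The supported route for Claude v3 on Bedrock remains the chat model:
from langchain_community.chat_models import BedrockChat

chat = BedrockChat(model_id="anthropic.claude-3-sonnet-20240229-v1:0")
```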