From 2b234a4d96d66af47a8b51284141d74decfc2251 Mon Sep 17 00:00:00 2001
From: Piyush Jain
Date: Wed, 6 Mar 2024 15:46:18 -0800
Subject: [PATCH] Support for claude v3 models. (#18630)

Fixes #18513.

## Description
This PR fixes support for Anthropic Claude v3 models in the BedrockChat LLM. The changes update the payload to use the `messages` format instead of the formatted text prompt for all models; the `messages` API is backwards compatible with all Anthropic models, so this should not break the experience for any model. A short usage sketch follows the patch below.

## Notes
In its current form, the PR does not add v3 support to the non-chat Bedrock LLM, so users still won't be able to use the v3 models with the Bedrock LLM. I can open a separate PR to tackle that use case; the intent here was to get this out quickly so users can start using and testing the chat LLM. The Bedrock LLM classes have also grown complex, with many conditions to support various providers and models, and are ripe for a refactor to make future changes more palatable. That refactor is likely to take longer and will require more thorough testing from the community. Credit to PRs [18579](https://github.com/langchain-ai/langchain/pull/18579) and [18548](https://github.com/langchain-ai/langchain/pull/18548) for some of the code here.

---------

Co-authored-by: Erick Friis
---
 .../chat_models/bedrock.py              | 170 ++++++++++++++++--
 .../langchain_community/llms/bedrock.py | 124 ++++++++++---
 2 files changed, 260 insertions(+), 34 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
index 5538372272b..f9d87b6274a 100644
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Iterator, List, Optional
+import re
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
@@ -19,6 +20,110 @@ from langchain_community.utilities.anthropic import (
 )
 
 
+def _format_image(image_url: str) -> Dict:
+    """
+    Formats an image of format data:image/jpeg;base64,{b64_string}
+    to a dict for anthropic api
+
+    {
+      "type": "base64",
+      "media_type": "image/jpeg",
+      "data": "/9j/4AAQSkZJRg...",
+    }
+
+    And throws an error if it's not a b64 image
+    """
+    regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
+    match = re.match(regex, image_url)
+    if match is None:
+        raise ValueError(
+            "Anthropic only supports base64-encoded images currently."
+            " Example: data:image/png;base64,'/9j/4AAQSk'..."
+        )
+    return {
+        "type": "base64",
+        "media_type": match.group("media_type"),
+        "data": match.group("data"),
+    }
+
+
+def _format_anthropic_messages(
+    messages: List[BaseMessage],
+) -> Tuple[Optional[str], List[Dict]]:
+    """Format messages for anthropic."""
+
+    """
+    [
+        {
+            "role": _message_type_lookups[m.type],
+            "content": [_AnthropicMessageContent(text=m.content).dict()],
+        }
+        for m in messages
+    ]
+    """
+    system: Optional[str] = None
+    formatted_messages: List[Dict] = []
+    for i, message in enumerate(messages):
+        if message.type == "system":
+            if i != 0:
+                raise ValueError("System message must be at beginning of message list.")
+            if not isinstance(message.content, str):
+                raise ValueError(
+                    "System message must be a string, "
+                    f"instead was: {type(message.content)}"
+                )
+            system = message.content
+            continue
+
+        role = _message_type_lookups[message.type]
+        content: Union[str, List[Dict]]
+
+        if not isinstance(message.content, str):
+            # parse as dict
+            assert isinstance(
+                message.content, list
+            ), "Anthropic message content must be str or list of dicts"
+
+            # populate content
+            content = []
+            for item in message.content:
+                if isinstance(item, str):
+                    content.append(
+                        {
+                            "type": "text",
+                            "text": item,
+                        }
+                    )
+                elif isinstance(item, dict):
+                    if "type" not in item:
+                        raise ValueError("Dict content item must have a type key")
+                    if item["type"] == "image_url":
+                        # convert format
+                        source = _format_image(item["image_url"]["url"])
+                        content.append(
+                            {
+                                "type": "image",
+                                "source": source,
+                            }
+                        )
+                    else:
+                        content.append(item)
+                else:
+                    raise ValueError(
+                        f"Content items must be str or dict, instead was: {type(item)}"
+                    )
+        else:
+            content = message.content
+
+        formatted_messages.append(
+            {
+                "role": role,
+                "content": content,
+            }
+        )
+    return system, formatted_messages
+
+
 class ChatPromptAdapter:
     """Adapter class to prepare the inputs from Langchain to prompt format
     that Chat model expects.
@@ -44,6 +149,20 @@ class ChatPromptAdapter:
             )
         return prompt
 
+    @classmethod
+    def format_messages(
+        cls, provider: str, messages: List[BaseMessage]
+    ) -> Tuple[Optional[str], List[Dict]]:
+        if provider == "anthropic":
+            return _format_anthropic_messages(messages)
+
+        raise NotImplementedError(
+            f"Provider {provider} not supported for format_messages"
+        )
+
+
+_message_type_lookups = {"human": "user", "ai": "assistant"}
+
 
 class BedrockChat(BaseChatModel, BedrockBase):
     """A chat model that uses the Bedrock API."""
@@ -85,12 +204,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         provider = self._get_provider()
-        prompt = ChatPromptAdapter.convert_messages_to_prompt(
-            provider=provider, messages=messages
-        )
+        system = None
+        formatted_messages = None
+        if provider == "anthropic":
+            prompt = None
+            system, formatted_messages = ChatPromptAdapter.format_messages(
+                provider, messages
+            )
+        else:
+            prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                provider=provider, messages=messages
+            )
 
         for chunk in self._prepare_input_and_invoke_stream(
-            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            prompt=prompt,
+            system=system,
+            messages=formatted_messages,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
         ):
             delta = chunk.text
             yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
@@ -109,20 +241,34 @@ class BedrockChat(BaseChatModel, BedrockBase):
                 completion += chunk.text
         else:
             provider = self._get_provider()
-            prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
-            )
-
+            system = None
+            formatted_messages = None
             params: Dict[str, Any] = {**kwargs}
+            if provider == "anthropic":
+                prompt = None
+                system, formatted_messages = ChatPromptAdapter.format_messages(
+                    provider, messages
+                )
+            else:
+                prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                    provider=provider, messages=messages
+                )
+
             if stop:
                 params["stop_sequences"] = stop
 
             completion = self._prepare_input_and_invoke(
-                prompt=prompt, stop=stop, run_manager=run_manager, **params
+                prompt=prompt,
+                stop=stop,
+                run_manager=run_manager,
+                system=system,
+                messages=formatted_messages,
+                **params,
             )
 
-        message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=completion))]
+        )
 
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index f4d7b1d69e1..d126995b93c 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -77,6 +77,20 @@ def _human_assistant_format(input_text: str) -> str:
     return input_text
 
 
+def _stream_response_to_generation_chunk(
+    stream_response: Dict[str, Any],
+) -> GenerationChunk:
+    """Convert a stream response to a generation chunk."""
+    if not stream_response["delta"]:
+        return GenerationChunk(text="")
+    return GenerationChunk(
+        text=stream_response["delta"]["text"],
+        generation_info=dict(
+            finish_reason=stream_response.get("stop_reason", None),
+        ),
+    )
+
+
 class LLMInputOutputAdapter:
     """Adapter class to prepare the inputs from Langchain to a format
     that LLM model expects.
@@ -93,11 +107,26 @@ class LLMInputOutputAdapter:
 
     @classmethod
     def prepare_input(
-        cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
+        cls,
+        provider: str,
+        model_kwargs: Dict[str, Any],
+        prompt: Optional[str] = None,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
     ) -> Dict[str, Any]:
         input_body = {**model_kwargs}
         if provider == "anthropic":
-            input_body["prompt"] = _human_assistant_format(prompt)
+            if messages:
+                input_body["anthropic_version"] = "bedrock-2023-05-31"
+                input_body["messages"] = messages
+                if system:
+                    input_body["system"] = system
+                if "max_tokens" not in input_body:
+                    input_body["max_tokens"] = 1024
+            if prompt:
+                input_body["prompt"] = _human_assistant_format(prompt)
+                if "max_tokens_to_sample" not in input_body:
+                    input_body["max_tokens_to_sample"] = 1024
         elif provider in ("ai21", "cohere", "meta"):
             input_body["prompt"] = prompt
         elif provider == "amazon":
@@ -107,16 +136,17 @@ class LLMInputOutputAdapter:
         else:
             input_body["inputText"] = prompt
 
-        if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 256
-
         return input_body
 
     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            text = response_body.get("completion")
+            if "completion" in response_body:
+                text = response_body.get("completion")
+            elif "content" in response_body:
+                content = response_body.get("content")
+                text = content[0].get("text")
         else:
             response_body = json.loads(response.get("body").read())
 
@@ -136,14 +166,21 @@ class LLMInputOutputAdapter:
 
     @classmethod
     def prepare_output_stream(
-        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+        cls,
+        provider: str,
+        response: Any,
+        stop: Optional[List[str]] = None,
+        messages_api: bool = False,
     ) -> Iterator[GenerationChunk]:
         stream = response.get("body")
 
         if not stream:
             return
 
-        output_key = cls.provider_to_output_key_map.get(provider, None)
+        if messages_api:
+            output_key = "message"
+        else:
+            output_key = cls.provider_to_output_key_map.get(provider, "")
 
         if not output_key:
             raise ValueError(
@@ -161,15 +198,29 @@
                     chunk_obj["is_finished"] or chunk_obj[output_key] == ""
                 ):
                     return
+                elif messages_api and (chunk_obj.get("type") == "content_block_stop"):
+                    return
+
+                if messages_api and chunk_obj.get("type") in (
+                    "message_start",
+                    "content_block_start",
+                    "content_block_delta",
+                ):
+                    if chunk_obj.get("type") == "content_block_delta":
+                        chk = _stream_response_to_generation_chunk(chunk_obj)
+                        yield chk
+                    else:
+                        continue
+                else:
                 # chunk obj format varies with provider
-                yield GenerationChunk(
-                    text=chunk_obj[output_key],
-                    generation_info={
-                        GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
-                        if GUARDRAILS_BODY_KEY in chunk_obj
-                        else None,
-                    },
-                )
+                    yield GenerationChunk(
+                        text=chunk_obj[output_key],
+                        generation_info={
+                            GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
+                            if GUARDRAILS_BODY_KEY in chunk_obj
+                            else None,
+                        },
+                    )
 
     @classmethod
     async def aprepare_output_stream(
@@ -412,7 +463,9 @@ class BedrockBase(BaseModel, ABC):
 
     def _prepare_input_and_invoke(
         self,
-        prompt: str,
+        prompt: Optional[str] = None,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
@@ -423,7 +476,13 @@
         params = {**_model_kwargs, **kwargs}
         if self._guardrails_enabled:
             params.update(self._get_guardrails_canonical())
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        input_body = LLMInputOutputAdapter.prepare_input(
+            provider=provider,
+            model_kwargs=params,
+            prompt=prompt,
+            system=system,
+            messages=messages,
+        )
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
@@ -498,7 +557,9 @@ class BedrockBase(BaseModel, ABC):
 
     def _prepare_input_and_invoke_stream(
         self,
-        prompt: str,
+        prompt: Optional[str] = None,
+        system: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
@@ -524,7 +585,13 @@
         if self._guardrails_enabled:
             params.update(self._get_guardrails_canonical())
 
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        input_body = LLMInputOutputAdapter.prepare_input(
+            provider=provider,
+            prompt=prompt,
+            system=system,
+            messages=messages,
+            model_kwargs=params,
+        )
         body = json.dumps(input_body)
 
         request_options = {
@@ -546,7 +613,7 @@ class BedrockBase(BaseModel, ABC):
             raise ValueError(f"Error raised by bedrock service: {e}")
 
         for chunk in LLMInputOutputAdapter.prepare_output_stream(
-            provider, response, stop
+            provider, response, stop, True if messages else False
         ):
             yield chunk
             # verify and raise callback error if any middleware intervened
@@ -576,7 +643,9 @@ class BedrockBase(BaseModel, ABC):
         _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        input_body = LLMInputOutputAdapter.prepare_input(
+            provider=provider, prompt=prompt, model_kwargs=params
+        )
         body = json.dumps(input_body)
 
         response = await asyncio.get_running_loop().run_in_executor(
@@ -629,6 +698,17 @@ class Bedrock(LLM, BedrockBase):
 
     """
 
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        model_id = values["model_id"]
+        if model_id.startswith("anthropic.claude-3"):
+            raise ValueError(
+                "Claude v3 models are not supported by this LLM."
+                "Please use `from langchain_community.chat_models import BedrockChat` "
+                "instead."
+            )
+        return super().validate_environment(values)
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
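
For reviewers who want to exercise the change, here is a minimal usage sketch (not part of the patch). It assumes AWS credentials and region are already configured for `boto3`, that the account has been granted access to a Claude 3 model on Bedrock, and that the `anthropic.claude-3-sonnet-20240229-v1:0` model id is available in that region; the base64 image data is a placeholder.

```python
# Usage sketch only; the model id, credentials, and base64 payload below are
# assumptions, not part of this patch.
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

chat = BedrockChat(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # assumed model id
    model_kwargs={"temperature": 0.1},
)

# Text-only request: the messages are converted to the Anthropic `messages`
# payload via ChatPromptAdapter.format_messages and prepare_input.
print(chat.invoke([HumanMessage(content="What is the capital of France?")]).content)

# Multimodal request: image_url content must be a base64 data URI, as enforced
# by _format_image above; "<BASE64_DATA>" is a placeholder.
message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image in one sentence."},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,<BASE64_DATA>"},
        },
    ]
)
print(chat.invoke([message]).content)
```

With a `messages`-style request, `prepare_input` also injects `anthropic_version` and a default `max_tokens` of 1024; the plain `Bedrock` LLM still sends the text-prompt format, so it now rejects `anthropic.claude-3*` model ids and points users to `BedrockChat` instead.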