Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-01 09:04:03 +00:00)
feat: add bedrock chat model (#8017)
- Description: Add Bedrock implementation of Anthropic Claude for Chat
- Tag maintainer: @hwchase17, @baskaryan
- Twitter handle: @bwmatson

Co-authored-by: Bagatur <baskaryan@gmail.com>
commit 58d7d86e51 (parent a7c9bd30d4)
docs/extras/integrations/chat/bedrock.ipynb (new file, 106 lines)
@@ -0,0 +1,106 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bf733a38-db84-4363-89e2-de6735c37230",
   "metadata": {},
   "source": [
    "# Bedrock Chat\n",
    "\n",
    "[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d51edc81",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import BedrockChat\n",
    "from langchain.schema import HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\":0.1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    HumanMessage(\n",
    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
    "    )\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c253883f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
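The notebook relies on boto3's default credential chain. Since `BedrockChat` also inherits the client fields defined on `BedrockBase` (see `libs/langchain/langchain/llms/bedrock.py` later in this diff), a minimal sketch with an explicit profile would look like this — `default` is a hypothetical profile name from `~/.aws/credentials`, and `region_name` is optional:

```python
from langchain.chat_models import BedrockChat

# Sketch only: credentials_profile_name and region_name are BedrockBase
# fields; "default" is a hypothetical profile in ~/.aws/credentials.
chat = BedrockChat(
    credentials_profile_name="default",
    region_name="us-east-1",
    model_id="anthropic.claude-v2",
    model_kwargs={"temperature": 0.1},
)
```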
libs/langchain/langchain/chat_models/__init__.py
@@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
 from langchain.chat_models.anthropic import ChatAnthropic
 from langchain.chat_models.anyscale import ChatAnyscale
 from langchain.chat_models.azure_openai import AzureChatOpenAI
+from langchain.chat_models.bedrock import BedrockChat
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.fake import FakeListChatModel
 from langchain.chat_models.google_palm import ChatGooglePalm
@@ -35,6 +36,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
 __all__ = [
     "ChatOpenAI",
     "AzureChatOpenAI",
+    "BedrockChat",
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
     "ChatAnthropic",
libs/langchain/langchain/chat_models/anthropic.py
@@ -6,10 +6,6 @@ from langchain.callbacks.manager import (
 )
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.anthropic import _AnthropicCommon
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -18,7 +14,54 @@ from langchain.schema.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain.schema.output import ChatGenerationChunk
+from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+
+
+def _convert_one_message_to_text(
+    message: BaseMessage,
+    human_prompt: str,
+    ai_prompt: str,
+) -> str:
+    if isinstance(message, ChatMessage):
+        message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+    elif isinstance(message, HumanMessage):
+        message_text = f"{human_prompt} {message.content}"
+    elif isinstance(message, AIMessage):
+        message_text = f"{ai_prompt} {message.content}"
+    elif isinstance(message, SystemMessage):
+        message_text = f"{human_prompt} <admin>{message.content}</admin>"
+    else:
+        raise ValueError(f"Got unknown type {message}")
+    return message_text
+
+
+def convert_messages_to_prompt_anthropic(
+    messages: List[BaseMessage],
+    *,
+    human_prompt: str = "\n\nHuman:",
+    ai_prompt: str = "\n\nAssistant:",
+) -> str:
+    """Format a list of messages into a full prompt for the Anthropic model
+
+    Args:
+        messages (List[BaseMessage]): List of BaseMessage to combine.
+        human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
+        ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
+
+    Returns:
+        str: Combined string with necessary human_prompt and ai_prompt tags.
+    """
+    messages = messages.copy()  # don't mutate the original list
+
+    if not isinstance(messages[-1], AIMessage):
+        messages.append(AIMessage(content=""))
+
+    text = "".join(
+        _convert_one_message_to_text(message, human_prompt, ai_prompt)
+        for message in messages
+    )
+
+    # trim off the trailing ' ' that might come from the "Assistant: "
+    return text.rstrip()
+
+
 class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@@ -55,52 +98,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
     def lc_serializable(self) -> bool:
         return True
 
-    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
-        if isinstance(message, ChatMessage):
-            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
-        elif isinstance(message, HumanMessage):
-            message_text = f"{self.HUMAN_PROMPT} {message.content}"
-        elif isinstance(message, AIMessage):
-            message_text = f"{self.AI_PROMPT} {message.content}"
-        elif isinstance(message, SystemMessage):
-            message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
-        else:
-            raise ValueError(f"Got unknown type {message}")
-        return message_text
-
-    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
-        """Format a list of strings into a single string with necessary newlines.
-
-        Args:
-            messages (List[BaseMessage]): List of BaseMessage to combine.
-
-        Returns:
-            str: Combined string with necessary newlines.
-        """
-        return "".join(
-            self._convert_one_message_to_text(message) for message in messages
-        )
-
     def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
         """Format a list of messages into a full prompt for the Anthropic model
 
         Args:
             messages (List[BaseMessage]): List of BaseMessage to combine.
 
         Returns:
             str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
         """
-        messages = messages.copy()  # don't mutate the original list
-
-        if not self.AI_PROMPT:
-            raise NameError("Please ensure the anthropic package is loaded")
-
-        if not isinstance(messages[-1], AIMessage):
-            messages.append(AIMessage(content=""))
-        text = self._convert_messages_to_text(messages)
-        return (
-            text.rstrip()
-        )  # trim off the trailing ' ' that might come from the "Assistant: "
+        prompt_params = {}
+        if self.HUMAN_PROMPT:
+            prompt_params["human_prompt"] = self.HUMAN_PROMPT
+        if self.AI_PROMPT:
+            prompt_params["ai_prompt"] = self.AI_PROMPT
+        return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
 
     def _stream(
         self,
@@ -152,7 +162,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
-            prompt = self._convert_messages_to_prompt(messages)
+            prompt = self._convert_messages_to_prompt(
+                messages,
+            )
         params: Dict[str, Any] = {
             "prompt": prompt,
             **self._default_params,
@@ -177,7 +189,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
             async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
-            prompt = self._convert_messages_to_prompt(messages)
+            prompt = self._convert_messages_to_prompt(
+                messages,
+            )
         params: Dict[str, Any] = {
             "prompt": prompt,
             **self._default_params,
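A short worked example of the new module-level helper: the expected string follows mechanically from the rules in `_convert_one_message_to_text` — system messages are folded into a human turn inside `<admin>` tags, and an empty assistant turn is appended so the prompt always ends with the Assistant tag:

```python
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.schema.messages import HumanMessage, SystemMessage

prompt = convert_messages_to_prompt_anthropic(
    [SystemMessage(content="You are helpful."), HumanMessage(content="Hello")]
)
# SystemMessage -> "\n\nHuman: <admin>You are helpful.</admin>"
# HumanMessage  -> "\n\nHuman: Hello"
# appended empty AIMessage -> "\n\nAssistant: ", then rstrip()
assert prompt == (
    "\n\nHuman: <admin>You are helpful.</admin>\n\nHuman: Hello\n\nAssistant:"
)
```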
libs/langchain/langchain/chat_models/bedrock.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
from langchain.chat_models.base import BaseChatModel
from langchain.llms.bedrock import BedrockBase
from langchain.pydantic_v1 import Extra
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult


class ChatPromptAdapter:
    """Adapter class to prepare the inputs from Langchain to prompt format
    that Chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        if provider == "anthropic":
            prompt = convert_messages_to_prompt_anthropic(messages=messages)
        else:
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return prompt


class BedrockChat(BaseChatModel, BedrockBase):
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        raise NotImplementedError(
            """Bedrock doesn't support stream requests at the moment."""
        )

    def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        raise NotImplementedError(
            """Bedrock doesn't support async requests at the moment."""
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        provider = self._get_provider()
        prompt = ChatPromptAdapter.convert_messages_to_prompt(
            provider=provider, messages=messages
        )

        params: Dict[str, Any] = {**kwargs}
        if stop:
            params["stop_sequences"] = stop

        completion = self._prepare_input_and_invoke(
            prompt=prompt, stop=stop, run_manager=run_manager, **params
        )

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError(
            """Bedrock doesn't support async stream requests at the moment."""
        )
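To make the dispatch in `_generate` concrete: the provider is the prefix of the model id, and only `anthropic` is routed through `ChatPromptAdapter` so far. A sketch, assuming the module above is importable:

```python
from langchain.chat_models.bedrock import ChatPromptAdapter
from langchain.schema.messages import HumanMessage

messages = [HumanMessage(content="Hello")]

# "anthropic.claude-v2".split(".")[0] == "anthropic" -> Claude-style prompt.
print(ChatPromptAdapter.convert_messages_to_prompt("anthropic", messages))
# "\n\nHuman: Hello\n\nAssistant:"

# Any other provider prefix (e.g. "amazon") raises NotImplementedError,
# since only the Anthropic prompt format is implemented for chat.
```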
libs/langchain/langchain/llms/bedrock.py
@@ -1,10 +1,11 @@
 import json
+from abc import ABC
 from typing import Any, Dict, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.pydantic_v1 import Extra, root_validator
+from langchain.pydantic_v1 import BaseModel, Extra, root_validator
 
 
 class LLMInputOutputAdapter:
@@ -47,33 +48,7 @@ class LLMInputOutputAdapter:
         return response_body.get("results")[0].get("outputText")
 
 
-class Bedrock(LLM):
-    """Bedrock models.
-
-    To authenticate, the AWS client uses the following methods to
-    automatically load credentials:
-    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
-
-    If a specific credential profile should be used, you must pass
-    the name of the profile from the ~/.aws/credentials file that is to be used.
-
-    Make sure the credentials / roles used have the required policies to
-    access the Bedrock service.
-    """
-
-    """
-    Example:
-        .. code-block:: python
-
-            from bedrock_langchain.bedrock_llm import BedrockLLM
-
-            llm = BedrockLLM(
-                credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
-            )
-
-    """
-
+class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
 
     region_name: Optional[str] = None
@@ -99,11 +74,6 @@
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -151,11 +121,77 @@
             **{"model_kwargs": _model_kwargs},
         }
 
+    def _get_provider(self) -> str:
+        return self.model_id.split(".")[0]
+
+    def _prepare_input_and_invoke(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        _model_kwargs = self.model_kwargs or {}
+
+        provider = self._get_provider()
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+        accept = "application/json"
+        contentType = "application/json"
+
+        try:
+            response = self.client.invoke_model(
+                body=body, modelId=self.model_id, accept=accept, contentType=contentType
+            )
+            text = LLMInputOutputAdapter.prepare_output(provider, response)
+
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+
+        return text
+
+
+class Bedrock(LLM, BedrockBase):
+    """Bedrock models.
+
+    To authenticate, the AWS client uses the following methods to
+    automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Bedrock service.
+    """
+
+    """
+    Example:
+        .. code-block:: python
+
+            from bedrock_langchain.bedrock_llm import BedrockLLM
+
+            llm = BedrockLLM(
+                credentials_profile_name="default",
+                model_id="amazon.titan-tg1-large"
+            )
+
+    """
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "amazon_bedrock"
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
     def _call(
         self,
         prompt: str,
@@ -177,25 +213,7 @@
 
             response = se("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-
-        provider = self.model_id.split(".")[0]
-        params = {**_model_kwargs, **kwargs}
-        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
-        body = json.dumps(input_body)
-        accept = "application/json"
-        contentType = "application/json"
-
-        try:
-            response = self.client.invoke_model(
-                body=body, modelId=self.model_id, accept=accept, contentType=contentType
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
-
-        except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}")
-
-        if stop is not None:
-            text = enforce_stop_tokens(text, stop)
+        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
 
         return text
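Note that the `Example:` docstring above still shows a stale pre-merge import path (`bedrock_langchain.bedrock_llm`). A minimal sketch of the merged API, assuming AWS credentials are configured as the class docstring describes:

```python
from langchain.llms import Bedrock

llm = Bedrock(
    credentials_profile_name="default",  # profile from ~/.aws/credentials
    model_id="amazon.titan-tg1-large",
)
# _get_provider() takes the prefix before the first ".", so this request is
# shaped for the "amazon" provider and the completion is read back via
# LLMInputOutputAdapter.prepare_output (results[0].outputText above).
response = llm("Tell me a joke.")
```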
libs/langchain/tests/integration_tests/chat_models/test_anthropic.py
@@ -4,11 +4,11 @@ from typing import List
 import pytest
 
 from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models.anthropic import ChatAnthropic
-from langchain.schema import (
-    ChatGeneration,
-    LLMResult,
-)
+from langchain.chat_models.anthropic import (
+    ChatAnthropic,
+    convert_messages_to_prompt_anthropic,
+)
+from langchain.schema import ChatGeneration, LLMResult
 from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
@@ -86,14 +86,12 @@ async def test_anthropic_async_streaming_callback() -> None:
 
 
 def test_formatting() -> None:
-    chat = ChatAnthropic()
-
-    chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
-    result = chat._convert_messages_to_prompt(chat_messages)
+    messages: List[BaseMessage] = [HumanMessage(content="Hello")]
+    result = convert_messages_to_prompt_anthropic(messages)
     assert result == "\n\nHuman: Hello\n\nAssistant:"
 
-    chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
-    result = chat._convert_messages_to_prompt(chat_messages)
+    messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
+    result = convert_messages_to_prompt_anthropic(messages)
     assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"