mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
feat: add bedrock chat model (#8017)
Replace this comment with: - Description: Add Bedrock implementation of Anthropic Claude for Chat - Tag maintainer: @hwchase17, @baskaryan - Twitter handle: @bwmatson --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
a7c9bd30d4
commit
58d7d86e51
106
docs/extras/integrations/chat/bedrock.ipynb
Normal file
106
docs/extras/integrations/chat/bedrock.ipynb
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bf733a38-db84-4363-89e2-de6735c37230",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Bedrock Chat\n",
|
||||||
|
"\n",
|
||||||
|
"[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d51edc81",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install boto3"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import BedrockChat\n",
|
||||||
|
"from langchain.schema import HumanMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\":0.1})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"messages = [\n",
|
||||||
|
" HumanMessage(\n",
|
||||||
|
" content=\"Translate this sentence from English to French. I love programming.\"\n",
|
||||||
|
" )\n",
|
||||||
|
"]\n",
|
||||||
|
"chat(messages)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c253883f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs.
|
|||||||
from langchain.chat_models.anthropic import ChatAnthropic
|
from langchain.chat_models.anthropic import ChatAnthropic
|
||||||
from langchain.chat_models.anyscale import ChatAnyscale
|
from langchain.chat_models.anyscale import ChatAnyscale
|
||||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||||
|
from langchain.chat_models.bedrock import BedrockChat
|
||||||
from langchain.chat_models.ernie import ErnieBotChat
|
from langchain.chat_models.ernie import ErnieBotChat
|
||||||
from langchain.chat_models.fake import FakeListChatModel
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
from langchain.chat_models.google_palm import ChatGooglePalm
|
from langchain.chat_models.google_palm import ChatGooglePalm
|
||||||
@ -35,6 +36,7 @@ from langchain.chat_models.vertexai import ChatVertexAI
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatOpenAI",
|
"ChatOpenAI",
|
||||||
"AzureChatOpenAI",
|
"AzureChatOpenAI",
|
||||||
|
"BedrockChat",
|
||||||
"FakeListChatModel",
|
"FakeListChatModel",
|
||||||
"PromptLayerChatOpenAI",
|
"PromptLayerChatOpenAI",
|
||||||
"ChatAnthropic",
|
"ChatAnthropic",
|
||||||
|
@ -6,10 +6,6 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.llms.anthropic import _AnthropicCommon
|
from langchain.llms.anthropic import _AnthropicCommon
|
||||||
from langchain.schema import (
|
|
||||||
ChatGeneration,
|
|
||||||
ChatResult,
|
|
||||||
)
|
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
@ -18,7 +14,54 @@ from langchain.schema.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.schema.output import ChatGenerationChunk
|
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_one_message_to_text(
|
||||||
|
message: BaseMessage,
|
||||||
|
human_prompt: str,
|
||||||
|
ai_prompt: str,
|
||||||
|
) -> str:
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_text = f"{human_prompt} {message.content}"
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_text = f"{ai_prompt} {message.content}"
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
message_text = f"{human_prompt} <admin>{message.content}</admin>"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
return message_text
|
||||||
|
|
||||||
|
|
||||||
|
def convert_messages_to_prompt_anthropic(
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
*,
|
||||||
|
human_prompt: str = "\n\nHuman:",
|
||||||
|
ai_prompt: str = "\n\nAssistant:",
|
||||||
|
) -> str:
|
||||||
|
"""Format a list of messages into a full prompt for the Anthropic model
|
||||||
|
Args:
|
||||||
|
messages (List[BaseMessage]): List of BaseMessage to combine.
|
||||||
|
human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
|
||||||
|
ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
|
||||||
|
Returns:
|
||||||
|
str: Combined string with necessary human_prompt and ai_prompt tags.
|
||||||
|
"""
|
||||||
|
|
||||||
|
messages = messages.copy() # don't mutate the original list
|
||||||
|
|
||||||
|
if not isinstance(messages[-1], AIMessage):
|
||||||
|
messages.append(AIMessage(content=""))
|
||||||
|
|
||||||
|
text = "".join(
|
||||||
|
_convert_one_message_to_text(message, human_prompt, ai_prompt)
|
||||||
|
for message in messages
|
||||||
|
)
|
||||||
|
|
||||||
|
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||||
|
return text.rstrip()
|
||||||
|
|
||||||
|
|
||||||
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
||||||
@ -55,52 +98,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
|||||||
def lc_serializable(self) -> bool:
|
def lc_serializable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _convert_one_message_to_text(self, message: BaseMessage) -> str:
|
|
||||||
if isinstance(message, ChatMessage):
|
|
||||||
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
|
|
||||||
elif isinstance(message, HumanMessage):
|
|
||||||
message_text = f"{self.HUMAN_PROMPT} {message.content}"
|
|
||||||
elif isinstance(message, AIMessage):
|
|
||||||
message_text = f"{self.AI_PROMPT} {message.content}"
|
|
||||||
elif isinstance(message, SystemMessage):
|
|
||||||
message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown type {message}")
|
|
||||||
return message_text
|
|
||||||
|
|
||||||
def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
|
|
||||||
"""Format a list of strings into a single string with necessary newlines.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages (List[BaseMessage]): List of BaseMessage to combine.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Combined string with necessary newlines.
|
|
||||||
"""
|
|
||||||
return "".join(
|
|
||||||
self._convert_one_message_to_text(message) for message in messages
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
|
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
|
||||||
"""Format a list of messages into a full prompt for the Anthropic model
|
"""Format a list of messages into a full prompt for the Anthropic model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages (List[BaseMessage]): List of BaseMessage to combine.
|
messages (List[BaseMessage]): List of BaseMessage to combine.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
|
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
|
||||||
"""
|
"""
|
||||||
messages = messages.copy() # don't mutate the original list
|
prompt_params = {}
|
||||||
|
if self.HUMAN_PROMPT:
|
||||||
if not self.AI_PROMPT:
|
prompt_params["human_prompt"] = self.HUMAN_PROMPT
|
||||||
raise NameError("Please ensure the anthropic package is loaded")
|
if self.AI_PROMPT:
|
||||||
|
prompt_params["ai_prompt"] = self.AI_PROMPT
|
||||||
if not isinstance(messages[-1], AIMessage):
|
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)
|
||||||
messages.append(AIMessage(content=""))
|
|
||||||
text = self._convert_messages_to_text(messages)
|
|
||||||
return (
|
|
||||||
text.rstrip()
|
|
||||||
) # trim off the trailing ' ' that might come from the "Assistant: "
|
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -152,7 +162,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
|||||||
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||||
completion += chunk.text
|
completion += chunk.text
|
||||||
else:
|
else:
|
||||||
prompt = self._convert_messages_to_prompt(messages)
|
prompt = self._convert_messages_to_prompt(
|
||||||
|
messages,
|
||||||
|
)
|
||||||
params: Dict[str, Any] = {
|
params: Dict[str, Any] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
**self._default_params,
|
**self._default_params,
|
||||||
@ -177,7 +189,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
|
|||||||
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
|
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
|
||||||
completion += chunk.text
|
completion += chunk.text
|
||||||
else:
|
else:
|
||||||
prompt = self._convert_messages_to_prompt(messages)
|
prompt = self._convert_messages_to_prompt(
|
||||||
|
messages,
|
||||||
|
)
|
||||||
params: Dict[str, Any] = {
|
params: Dict[str, Any] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
**self._default_params,
|
**self._default_params,
|
||||||
|
98
libs/langchain/langchain/chat_models/bedrock.py
Normal file
98
libs/langchain/langchain/chat_models/bedrock.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.llms.bedrock import BedrockBase
|
||||||
|
from langchain.pydantic_v1 import Extra
|
||||||
|
from langchain.schema.messages import AIMessage, BaseMessage
|
||||||
|
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptAdapter:
|
||||||
|
"""Adapter class to prepare the inputs from Langchain to prompt format
|
||||||
|
that Chat model expects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_messages_to_prompt(
|
||||||
|
cls, provider: str, messages: List[BaseMessage]
|
||||||
|
) -> str:
|
||||||
|
if provider == "anthropic":
|
||||||
|
prompt = convert_messages_to_prompt_anthropic(messages=messages)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Provider {provider} model does not support chat."
|
||||||
|
)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockChat(BaseChatModel, BedrockBase):
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "amazon_bedrock_chat"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
def _stream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"""Bedrock doesn't support stream requests at the moment."""
|
||||||
|
)
|
||||||
|
|
||||||
|
def _astream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"""Bedrock doesn't support async requests at the moment."""
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
provider = self._get_provider()
|
||||||
|
prompt = ChatPromptAdapter.convert_messages_to_prompt(
|
||||||
|
provider=provider, messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
params: Dict[str, Any] = {**kwargs}
|
||||||
|
if stop:
|
||||||
|
params["stop_sequences"] = stop
|
||||||
|
|
||||||
|
completion = self._prepare_input_and_invoke(
|
||||||
|
prompt=prompt, stop=stop, run_manager=run_manager, **params
|
||||||
|
)
|
||||||
|
|
||||||
|
message = AIMessage(content=completion)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"""Bedrock doesn't support async stream requests at the moment."""
|
||||||
|
)
|
@ -1,10 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
|
from abc import ABC
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.llms.utils import enforce_stop_tokens
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
from langchain.pydantic_v1 import Extra, root_validator
|
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
|
||||||
class LLMInputOutputAdapter:
|
class LLMInputOutputAdapter:
|
||||||
@ -47,33 +48,7 @@ class LLMInputOutputAdapter:
|
|||||||
return response_body.get("results")[0].get("outputText")
|
return response_body.get("results")[0].get("outputText")
|
||||||
|
|
||||||
|
|
||||||
class Bedrock(LLM):
|
class BedrockBase(BaseModel, ABC):
|
||||||
"""Bedrock models.
|
|
||||||
|
|
||||||
To authenticate, the AWS client uses the following methods to
|
|
||||||
automatically load credentials:
|
|
||||||
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
|
||||||
|
|
||||||
If a specific credential profile should be used, you must pass
|
|
||||||
the name of the profile from the ~/.aws/credentials file that is to be used.
|
|
||||||
|
|
||||||
Make sure the credentials / roles used have the required policies to
|
|
||||||
access the Bedrock service.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from bedrock_langchain.bedrock_llm import BedrockLLM
|
|
||||||
|
|
||||||
llm = BedrockLLM(
|
|
||||||
credentials_profile_name="default",
|
|
||||||
model_id="amazon.titan-tg1-large"
|
|
||||||
)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
|
|
||||||
region_name: Optional[str] = None
|
region_name: Optional[str] = None
|
||||||
@ -99,11 +74,6 @@ class Bedrock(LLM):
|
|||||||
endpoint_url: Optional[str] = None
|
endpoint_url: Optional[str] = None
|
||||||
"""Needed if you don't want to default to us-east-1 endpoint"""
|
"""Needed if you don't want to default to us-east-1 endpoint"""
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that AWS credentials to and python package exists in environment."""
|
"""Validate that AWS credentials to and python package exists in environment."""
|
||||||
@ -151,11 +121,77 @@ class Bedrock(LLM):
|
|||||||
**{"model_kwargs": _model_kwargs},
|
**{"model_kwargs": _model_kwargs},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_provider(self) -> str:
|
||||||
|
return self.model_id.split(".")[0]
|
||||||
|
|
||||||
|
def _prepare_input_and_invoke(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
|
||||||
|
provider = self._get_provider()
|
||||||
|
params = {**_model_kwargs, **kwargs}
|
||||||
|
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
|
||||||
|
body = json.dumps(input_body)
|
||||||
|
accept = "application/json"
|
||||||
|
contentType = "application/json"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.invoke_model(
|
||||||
|
body=body, modelId=self.model_id, accept=accept, contentType=contentType
|
||||||
|
)
|
||||||
|
text = LLMInputOutputAdapter.prepare_output(provider, response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error raised by bedrock service: {e}")
|
||||||
|
|
||||||
|
if stop is not None:
|
||||||
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class Bedrock(LLM, BedrockBase):
|
||||||
|
"""Bedrock models.
|
||||||
|
|
||||||
|
To authenticate, the AWS client uses the following methods to
|
||||||
|
automatically load credentials:
|
||||||
|
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
|
||||||
|
|
||||||
|
If a specific credential profile should be used, you must pass
|
||||||
|
the name of the profile from the ~/.aws/credentials file that is to be used.
|
||||||
|
|
||||||
|
Make sure the credentials / roles used have the required policies to
|
||||||
|
access the Bedrock service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from bedrock_langchain.bedrock_llm import BedrockLLM
|
||||||
|
|
||||||
|
llm = BedrockLLM(
|
||||||
|
credentials_profile_name="default",
|
||||||
|
model_id="amazon.titan-tg1-large"
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "amazon_bedrock"
|
return "amazon_bedrock"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -177,25 +213,7 @@ class Bedrock(LLM):
|
|||||||
|
|
||||||
response = se("Tell me a joke.")
|
response = se("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
_model_kwargs = self.model_kwargs or {}
|
|
||||||
|
|
||||||
provider = self.model_id.split(".")[0]
|
text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
|
||||||
params = {**_model_kwargs, **kwargs}
|
|
||||||
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
|
|
||||||
body = json.dumps(input_body)
|
|
||||||
accept = "application/json"
|
|
||||||
contentType = "application/json"
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = self.client.invoke_model(
|
|
||||||
body=body, modelId=self.model_id, accept=accept, contentType=contentType
|
|
||||||
)
|
|
||||||
text = LLMInputOutputAdapter.prepare_output(provider, response)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Error raised by bedrock service: {e}")
|
|
||||||
|
|
||||||
if stop is not None:
|
|
||||||
text = enforce_stop_tokens(text, stop)
|
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
@ -4,11 +4,11 @@ from typing import List
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
from langchain.chat_models.anthropic import ChatAnthropic
|
from langchain.chat_models.anthropic import (
|
||||||
from langchain.schema import (
|
ChatAnthropic,
|
||||||
ChatGeneration,
|
convert_messages_to_prompt_anthropic,
|
||||||
LLMResult,
|
|
||||||
)
|
)
|
||||||
|
from langchain.schema import ChatGeneration, LLMResult
|
||||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
|
|
||||||
@ -86,14 +86,12 @@ async def test_anthropic_async_streaming_callback() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_formatting() -> None:
|
def test_formatting() -> None:
|
||||||
chat = ChatAnthropic()
|
messages: List[BaseMessage] = [HumanMessage(content="Hello")]
|
||||||
|
result = convert_messages_to_prompt_anthropic(messages)
|
||||||
chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
|
|
||||||
result = chat._convert_messages_to_prompt(chat_messages)
|
|
||||||
assert result == "\n\nHuman: Hello\n\nAssistant:"
|
assert result == "\n\nHuman: Hello\n\nAssistant:"
|
||||||
|
|
||||||
chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
|
messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
|
||||||
result = chat._convert_messages_to_prompt(chat_messages)
|
result = convert_messages_to_prompt_anthropic(messages)
|
||||||
assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
|
assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user