Amazon Bedrock Support Streaming (#10393)
### Description

- Add support for streaming with the `Bedrock` LLM and the `BedrockChat` chat model.
- Bedrock currently supports streaming only for the `anthropic.claude-*` and `amazon.titan-*` models, so support is built for those.
- Also increase the default `max_tokens_to_sample` for the Bedrock `anthropic` model provider from `50` to `256`, in line with the `Anthropic` defaults.
- Add streaming-response examples to the Bedrock example notebooks.

**_NOTE:_** This PR fixes the issues mentioned in #9897 and makes that PR redundant.
This commit is contained in:
parent 0749a642f5
commit 67c5950df3
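Taken together, the changes below enable usage like the following sketch. The `bedrock-admin` profile name and the model IDs are the illustrative values used in the docs examples; any configured AWS credentials and a Bedrock-enabled region work the same way:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import BedrockChat
from langchain.llms import Bedrock
from langchain.schema import HumanMessage

# Chat model: tokens are pushed to the callback handler as they arrive.
chat = BedrockChat(
    model_id="anthropic.claude-v2",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
    model_kwargs={"temperature": 0.1},
)
chat([HumanMessage(content="Translate this sentence from English to French. I love programming.")])

# Text-completion LLM: same pattern with the Titan model.
llm = Bedrock(
    credentials_profile_name="bedrock-admin",
    model_id="amazon.titan-tg1-large",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)
llm("Tell me a joke.")
```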
BedrockChat example notebook:

@@ -22,7 +22,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
    "metadata": {
     "tags": []
@@ -73,13 +73,46 @@
     "chat(messages)"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "a4a4f4d4",
+   "metadata": {},
+   "source": [
+    "### For BedrockChat with Streaming"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c253883f",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "chat = BedrockChat(\n",
+    "    model_id=\"anthropic.claude-v2\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    "    model_kwargs={\"temperature\": 0.1},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e52838",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(\n",
+    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
+    "    )\n",
+    "]\n",
+    "chat(messages)"
+   ]
+  }
   }
  ],
  "metadata": {
@@ -98,7 +131,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.9"
   }
  },
  "nbformat": 4,
Bedrock LLM example notebook:

@@ -61,6 +61,46 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Conversation Chain With Streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import Bedrock\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"amazon.titan-tg1-large\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "conversation = ConversationChain(\n",
+    "    llm=llm, verbose=True, memory=ConversationBufferMemory()\n",
+    ")\n",
+    "\n",
+    "conversation.predict(input=\"Hi there!\")"
+   ]
+  }
   }
  ],
  "metadata": {
docs/package-lock.json (generated; 6 lines deleted):

@@ -1,6 +0,0 @@
-{
-  "name": "docs",
-  "lockfileVersion": 3,
-  "requires": true,
-  "packages": {}
-}
langchain/chat_models/bedrock.py:

@@ -8,7 +8,7 @@ from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.bedrock import BedrockBase
 from langchain.pydantic_v1 import Extra
-from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -48,10 +48,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError(
-            """Bedrock doesn't support stream requests at the moment."""
+        provider = self._get_provider()
+        prompt = ChatPromptAdapter.convert_messages_to_prompt(
+            provider=provider, messages=messages
         )

+        for chunk in self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            delta = chunk.text
+            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+
     def _astream(
         self,
         messages: List[BaseMessage],
@@ -70,18 +77,24 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        provider = self._get_provider()
-        prompt = ChatPromptAdapter.convert_messages_to_prompt(
-            provider=provider, messages=messages
-        )
+        completion = ""

-        params: Dict[str, Any] = {**kwargs}
-        if stop:
-            params["stop_sequences"] = stop
+        if self.streaming:
+            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
+        else:
+            provider = self._get_provider()
+            prompt = ChatPromptAdapter.convert_messages_to_prompt(
+                provider=provider, messages=messages
+            )

-        completion = self._prepare_input_and_invoke(
-            prompt=prompt, stop=stop, run_manager=run_manager, **params
-        )
+            params: Dict[str, Any] = {**kwargs}
+            if stop:
+                params["stop_sequences"] = stop
+
+            completion = self._prepare_input_and_invoke(
+                prompt=prompt, stop=stop, run_manager=run_manager, **params
+            )

         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
langchain/llms/bedrock.py:

@@ -1,11 +1,12 @@
 import json
 from abc import ABC
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain.schema.output import GenerationChunk


 class LLMInputOutputAdapter:
@@ -15,6 +16,11 @@ class LLMInputOutputAdapter:
     It also provides helper function to extract
     the generated text from the model response."""

+    provider_to_output_key_map = {
+        "anthropic": "completion",
+        "amazon": "outputText",
+    }
+
     @classmethod
     def prepare_input(
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
@@ -30,7 +36,7 @@
             input_body["inputText"] = prompt

         if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 50
+            input_body["max_tokens_to_sample"] = 256

         return input_body
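A quick way to see the new default in action (a hedged sketch; only the `max_tokens_to_sample` behavior from the hunk above is asserted):

```python
from langchain.llms.bedrock import LLMInputOutputAdapter

# prepare_input fills in the Anthropic default when the caller
# hasn't set max_tokens_to_sample explicitly.
body = LLMInputOutputAdapter.prepare_input(
    provider="anthropic", prompt="Hello", model_kwargs={"temperature": 0.1}
)
assert body["max_tokens_to_sample"] == 256  # previously defaulted to 50
```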
@@ -47,6 +53,30 @@
         else:
             return response_body.get("results")[0].get("outputText")

+    @classmethod
+    def prepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> Iterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        if provider not in cls.provider_to_output_key_map:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if chunk:
+                chunk_obj = json.loads(chunk.get("bytes").decode())
+
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
+                )
+

 class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
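For orientation, `prepare_output_stream` consumes the boto3 `invoke_model_with_response_stream` response, whose `body` is an event stream of JSON chunks keyed per provider (`completion` for Anthropic, `outputText` for Titan). A rough sketch with a faked response (the payload values are made up; a plain list stands in for the boto3 EventStream):

```python
import json

from langchain.llms.bedrock import LLMInputOutputAdapter

# Each event carries JSON-encoded bytes under "chunk".
fake_response = {
    "body": [
        {"chunk": {"bytes": json.dumps({"completion": "Hello"}).encode()}},
        {"chunk": {"bytes": json.dumps({"completion": " world"}).encode()}},
    ]
}

for generation_chunk in LLMInputOutputAdapter.prepare_output_stream("anthropic", fake_response):
    print(generation_chunk.text, end="")  # prints "Hello world"
```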
@@ -74,6 +104,15 @@
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""

+    streaming: bool = False
+    """Whether to stream the results."""
+
+    provider_stop_sequence_key_name_map: Mapping[str, str] = {
+        "anthropic": "stop_sequences",
+        "amazon": "stopSequences",
+        "ai21": "stop_sequences",
+    }
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -154,6 +193,49 @@

         return text

+    def _prepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+
+            # stop sequence from _generate() overrides
+            # stop sequences in the class attribute
+            _model_kwargs[
+                self.provider_stop_sequence_key_name_map.get(provider),
+            ] = stop
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        try:
+            response = self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+

 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -177,7 +259,8 @@ class Bedrock(LLM, BedrockBase):

             llm = BedrockLLM(
                 credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
+                model_id="amazon.titan-tg1-large",
+                streaming=True
             )

     """
@@ -192,6 +275,33 @@ class Bedrock(LLM, BedrockBase):

         extra = Extra.forbid

+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Returns:
+            Iterator[GenerationChunk]: Generator that yields the streamed responses.
+
+        Yields:
+            Iterator[GenerationChunk]: Responses from the model.
+        """
+        return self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
+
     def _call(
         self,
         prompt: str,
@@ -211,9 +321,15 @@ class Bedrock(LLM, BedrockBase):
         Example:
             .. code-block:: python

-                response = se("Tell me a joke.")
+                response = llm("Tell me a joke.")
         """

-        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion

-        return text
+        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)