Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-01 19:03:25 +00:00)
Amazon Bedrock Support Streaming (#10393)
### Description

- Add support for streaming with the `Bedrock` LLM and the `BedrockChat` chat model (a short usage sketch follows below).
- Bedrock currently supports streaming only for the `anthropic.claude-*` and `amazon.titan-*` models, so support has been built for those.
- Increased the default `max_tokens_to_sample` for the Bedrock `anthropic` model provider from `50` to `256`, in line with the `Anthropic` defaults.
- Added streaming-response examples to the Bedrock example notebooks.

**_NOTE:_** This PR fixes the issues mentioned in #9897 and makes that PR redundant.
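Concretely, streaming is opted into with `streaming=True` plus a callback handler that consumes tokens. A minimal sketch based on the example notebooks in this PR (the model id and sampling settings are just the values used there):

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import BedrockChat
from langchain.schema import HumanMessage

# Tokens are written to stdout as they arrive; AWS credentials are resolved from the
# default boto3 configuration unless credentials_profile_name is set.
chat = BedrockChat(
    model_id="anthropic.claude-v2",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
    model_kwargs={"temperature": 0.1},
)
chat([HumanMessage(content="Translate this sentence from English to French. I love programming.")])
```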
This commit is contained in: parent 0749a642f5, commit 67c5950df3
@@ -22,7 +22,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
    "metadata": {
     "tags": []
@@ -73,13 +73,46 @@
     "chat(messages)"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "a4a4f4d4",
+   "metadata": {},
+   "source": [
+    "### For BedrockChat with Streaming"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c253883f",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "chat = BedrockChat(\n",
+    "    model_id=\"anthropic.claude-v2\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    "    model_kwargs={\"temperature\": 0.1},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9e52838",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "messages = [\n",
+    "    HumanMessage(\n",
+    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
+    "    )\n",
+    "]\n",
+    "chat(messages)"
+   ]
   }
  ],
  "metadata": {
@@ -98,7 +131,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.4"
+   "version": "3.10.9"
   }
  },
  "nbformat": 4,
@@ -61,6 +61,46 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Conversation Chain With Streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import Bedrock\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
+    "\n",
+    "\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"amazon.titan-tg1-large\",\n",
+    "    streaming=True,\n",
+    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "conversation = ConversationChain(\n",
+    "    llm=llm, verbose=True, memory=ConversationBufferMemory()\n",
+    ")\n",
+    "\n",
+    "conversation.predict(input=\"Hi there!\")"
+   ]
   }
  ],
  "metadata": {
docs/package-lock.json (generated)
@@ -1,6 +0,0 @@
-{
-  "name": "docs",
-  "lockfileVersion": 3,
-  "requires": true,
-  "packages": {}
-}
@@ -8,7 +8,7 @@ from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.bedrock import BedrockBase
 from langchain.pydantic_v1 import Extra
-from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
 
 
@@ -48,10 +48,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        raise NotImplementedError(
-            """Bedrock doesn't support stream requests at the moment."""
+        provider = self._get_provider()
+        prompt = ChatPromptAdapter.convert_messages_to_prompt(
+            provider=provider, messages=messages
         )
 
+        for chunk in self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            delta = chunk.text
+            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+
     def _astream(
         self,
         messages: List[BaseMessage],
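For context, `BedrockChat._stream` now wraps each piece of provider text in an `AIMessageChunk` inside a `ChatGenerationChunk`. A small sketch (assuming the chunk classes compose with `+` as they do in `langchain.schema` of this release line):

```python
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk

# Not part of the PR: an illustration of how streamed chat chunks accumulate.
first = ChatGenerationChunk(message=AIMessageChunk(content="Bonjour"))
second = ChatGenerationChunk(message=AIMessageChunk(content=", le monde"))
combined = first + second
print(combined.message.content)  # Bonjour, le monde
```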
@@ -70,6 +77,12 @@ class BedrockChat(BaseChatModel, BedrockBase):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        completion = ""
+
+        if self.streaming:
+            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
+        else:
             provider = self._get_provider()
             prompt = ChatPromptAdapter.convert_messages_to_prompt(
                 provider=provider, messages=messages
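The streaming branch of `_generate` drains `_stream` and still returns a single message, so existing call sites are unaffected. A hedged sketch, reusing the `chat` instance constructed in the notebook example above:

```python
from langchain.schema import AIMessage, HumanMessage

# Assumes `chat` was built with streaming=True as in the notebook cell shown earlier.
reply = chat([HumanMessage(content="Say hello in French.")])
assert isinstance(reply, AIMessage)  # a full message comes back whether or not streaming is on
```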
@@ -1,11 +1,12 @@
 import json
 from abc import ABC
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain.schema.output import GenerationChunk
 
 
 class LLMInputOutputAdapter:
@@ -15,6 +16,11 @@ class LLMInputOutputAdapter:
     It also provides helper function to extract
     the generated text from the model response."""
 
+    provider_to_output_key_map = {
+        "anthropic": "completion",
+        "amazon": "outputText",
+    }
+
     @classmethod
     def prepare_input(
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
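For reference, a sketch of the decoded chunk payloads this map is keyed against; only the mapped field is implied by the code, and the other keys shown are assumptions about the providers' wire formats:

```python
# Assumed per-provider chunk payloads (fields beyond the mapped key are illustrative).
anthropic_chunk = {"completion": " Hello", "stop_reason": None}
titan_chunk = {"outputText": " Hello", "index": 0}

provider_to_output_key_map = {"anthropic": "completion", "amazon": "outputText"}
assert anthropic_chunk[provider_to_output_key_map["anthropic"]] == " Hello"
assert titan_chunk[provider_to_output_key_map["amazon"]] == " Hello"
```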
@@ -30,7 +36,7 @@ class LLMInputOutputAdapter:
             input_body["inputText"] = prompt
 
         if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 50
+            input_body["max_tokens_to_sample"] = 256
 
         return input_body
 
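The new `256` default only applies when the caller has not set `max_tokens_to_sample`; it can still be overridden per instance via `model_kwargs`, e.g. (profile name and values are illustrative):

```python
from langchain.llms import Bedrock

# Explicit model_kwargs take precedence over the 256-token default applied in prepare_input.
llm = Bedrock(
    credentials_profile_name="bedrock-admin",  # placeholder profile name
    model_id="anthropic.claude-v2",
    model_kwargs={"max_tokens_to_sample": 1024, "temperature": 0.1},
)
```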
@@ -47,6 +53,30 @@ class LLMInputOutputAdapter:
         else:
             return response_body.get("results")[0].get("outputText")
 
+    @classmethod
+    def prepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> Iterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        if provider not in cls.provider_to_output_key_map:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if chunk:
+                chunk_obj = json.loads(chunk.get("bytes").decode())
+
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
+                )
+
 
 class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
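A minimal sketch of driving `prepare_output_stream` directly; a plain list stands in for the `EventStream` that boto3 returns under `"body"`, and the event shape is the one the method expects:

```python
from langchain.llms.bedrock import LLMInputOutputAdapter

# Each event carries {"chunk": {"bytes": b"<json>"}}; here an Anthropic-style payload.
fake_response = {
    "body": [
        {"chunk": {"bytes": b'{"completion": "Hello"}'}},
        {"chunk": {"bytes": b'{"completion": ", world"}'}},
    ]
}
pieces = [
    chunk.text
    for chunk in LLMInputOutputAdapter.prepare_output_stream("anthropic", fake_response)
]
assert "".join(pieces) == "Hello, world"
```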
@@ -74,6 +104,15 @@ class BedrockBase(BaseModel, ABC):
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
+    provider_stop_sequence_key_name_map: Mapping[str, str] = {
+        "anthropic": "stop_sequences",
+        "amazon": "stopSequences",
+        "ai21": "stop_sequences",
+    }
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
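In other words, the caller-supplied `stop` list is renamed to the provider-specific request key before being merged into the model kwargs; a tiny sketch:

```python
# Same stop list, different request-body key per provider.
provider_stop_sequence_key_name_map = {
    "anthropic": "stop_sequences",
    "amazon": "stopSequences",
    "ai21": "stop_sequences",
}
stop = ["\n\nHuman:"]
model_kwargs = {provider_stop_sequence_key_name_map["amazon"]: stop}
# -> {"stopSequences": ["\n\nHuman:"]}
```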
@@ -154,6 +193,49 @@ class BedrockBase(BaseModel, ABC):
 
         return text
 
+    def _prepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+
+            # stop sequence from _generate() overrides
+            # stop sequences in the class attribute
+            _model_kwargs[
+                self.provider_stop_sequence_key_name_map.get(provider),
+            ] = stop
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        try:
+            response = self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
 
 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
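For orientation, roughly the same call made against the AWS SDK directly, as a hedged standalone sketch (the profile name, model id, and the `bedrock-runtime` service name are assumptions that depend on your account and boto3 version):

```python
import json

import boto3

session = boto3.Session(profile_name="bedrock-admin")  # placeholder profile
client = session.client("bedrock-runtime")

body = json.dumps(
    {"prompt": "\n\nHuman: Tell me a joke.\n\nAssistant:", "max_tokens_to_sample": 256}
)
response = client.invoke_model_with_response_stream(
    body=body,
    modelId="anthropic.claude-v2",
    accept="application/json",
    contentType="application/json",
)
# response["body"] is an EventStream of {"chunk": {"bytes": ...}} events.
for event in response["body"]:
    chunk = event.get("chunk")
    if chunk:
        print(json.loads(chunk["bytes"].decode())["completion"], end="", flush=True)
```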
@@ -177,7 +259,8 @@ class Bedrock(LLM, BedrockBase):
 
             llm = BedrockLLM(
                 credentials_profile_name="default",
-                model_id="amazon.titan-tg1-large"
+                model_id="amazon.titan-tg1-large",
+                streaming=True
             )
 
     """
@@ -192,6 +275,33 @@ class Bedrock(LLM, BedrockBase):
 
         extra = Extra.forbid
 
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Returns:
+            Iterator[GenerationChunk]: Generator that yields the streamed responses.
+
+        Yields:
+            Iterator[GenerationChunk]: Responses from the model.
+        """
+        return self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
+
     def _call(
         self,
         prompt: str,
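With `_stream` implemented, the generic `.stream()` interface on LangChain LLMs should also yield plain text pieces as they arrive; a hedged sketch, assuming the base class of this release line wires `.stream()` to `_stream`:

```python
from langchain.llms import Bedrock

llm = Bedrock(
    credentials_profile_name="bedrock-admin",  # placeholder profile
    model_id="anthropic.claude-v2",
)
# Each `piece` is a text fragment produced as the model generates.
for piece in llm.stream("Write a one-line haiku about rivers."):
    print(piece, end="", flush=True)
```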
@@ -211,9 +321,15 @@ class Bedrock(LLM, BedrockBase):
         Example:
             .. code-block:: python
 
-                response = se("Tell me a joke.")
+                response = llm("Tell me a joke.")
         """
 
-        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion
 
-        return text
+        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
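Net effect of the `_call` change, as a sketch: with `streaming=True` the completion is assembled from the streamed chunks (and any callbacks see tokens in flight), but the return value is still the full string, so the public call is unchanged:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Bedrock

llm = Bedrock(
    credentials_profile_name="bedrock-admin",  # placeholder profile
    model_id="anthropic.claude-v2",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)
joke = llm("Tell me a joke.")  # tokens print as they stream; `joke` holds the full text
```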