Added Streaming Capability to SageMaker LLMs (#10535)

This PR adds the ability to request a streaming response from the
SageMaker LLM by leveraging the `invoke_endpoint_with_response_stream`
capability in `boto3`. It is heavily based on the AWS blog post
announcing streaming support for Amazon SageMaker hosting, linked
[here](https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/).

It does not add any new dependencies, since it relies on the version of
`boto3` that is already required.
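
For context, the underlying `boto3` call looks roughly like the sketch below. This is not part of the PR; the endpoint name and payload schema are illustrative placeholders, and a `boto3`/`botocore` version recent enough to expose the streaming API is assumed:

```python
import json

import boto3

client = boto3.client("sagemaker-runtime")

# Hypothetical endpoint name; any SageMaker endpoint backed by a
# streaming-capable container would work here.
response = client.invoke_endpoint_with_response_stream(
    EndpointName="my-llm-endpoint",
    Body=json.dumps({"inputs": "What is streaming?"}),
    ContentType="application/json",
)

# The response body is an event stream; each event may carry a
# PayloadPart holding a chunk of bytes (not necessarily a full JSON line).
for event in response["Body"]:
    if "PayloadPart" in event:
        print(event["PayloadPart"]["Bytes"].decode("utf-8"), end="")
```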

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Committed by Juan Daza on 2023-10-06 (commit 4236ae3851).
```diff
@@ -1,6 +1,8 @@
 """Sagemaker InvokeEndpoint API."""
+import io
+import json
 from abc import abstractmethod
-from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
```
```diff
@@ -8,7 +10,66 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra, root_validator
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
-OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
+OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
+
+
+class LineIterator:
+    """
+    A helper class for parsing the byte stream input.
+
+    The output of the model will be in the following format:
+
+    b'{"outputs": [" a"]}\n'
+    b'{"outputs": [" challenging"]}\n'
+    b'{"outputs": [" problem"]}\n'
+    ...
+
+    While usually each PayloadPart event from the event stream will
+    contain a byte array with a full json, this is not guaranteed
+    and some of the json objects may be split across PayloadPart events.
+
+    For example:
+
+    {'PayloadPart': {'Bytes': b'{"outputs": '}}
+    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+
+    This class accounts for this by concatenating bytes written via the
+    'write' function and then exposing a method which will return lines
+    (ending with a '\n' character) within the buffer via the 'scan_lines'
+    function. It maintains the position of the last read byte to ensure
+    that previous bytes are not exposed again.
+
+    For more details see:
+    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
+    """
+
+    def __init__(self, stream: Any) -> None:
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self) -> "LineIterator":
+        return self
+
+    def __next__(self) -> Any:
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                # Unknown Event Type
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
 class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
```
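
To illustrate why the buffering matters, here is a small sketch (not part of the PR) that feeds `LineIterator` a simulated event stream in which one JSON object is split across two `PayloadPart` events, assuming the class above is in scope:

```python
import json

# Simulated event stream: the second JSON object is split across two
# PayloadPart events, as can happen with real endpoints.
events = [
    {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
    {"PayloadPart": {"Bytes": b'{"outputs": '}},
    {"PayloadPart": {"Bytes": b'[" problem"]}\n'}},
]

# LineIterator buffers the chunks and only yields complete lines.
for line in LineIterator(events):
    print(json.loads(line)["outputs"][0])
# Prints " a" then " problem": the split object comes out as one line.
```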
```diff
@@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
         and the endpoint.
     """
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
     """
     Example:
         .. code-block:: python
```
```diff
@@ -264,7 +328,28 @@ class SagemakerEndpoint(LLM):
         content_type = self.content_handler.content_type
         accepts = self.content_handler.accepts
 
-        # send request
-        try:
-            response = self.client.invoke_endpoint(
-                EndpointName=self.endpoint_name,
+        if self.streaming and run_manager:
+            try:
+                resp = self.client.invoke_endpoint_with_response_stream(
+                    EndpointName=self.endpoint_name,
+                    Body=body,
+                    ContentType=self.content_handler.content_type,
+                    **_endpoint_kwargs,
+                )
+                iterator = LineIterator(resp["Body"])
+                current_completion: str = ""
+                for line in iterator:
+                    resp = json.loads(line)
+                    resp_output = resp.get("outputs")[0]
+                    if stop is not None:
+                        # Uses same approach as below
+                        resp_output = enforce_stop_tokens(resp_output, stop)
+                    current_completion += resp_output
+                    run_manager.on_llm_new_token(resp_output)
+                return current_completion
+            except Exception as e:
+                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+        else:
+            try:
+                response = self.client.invoke_endpoint(
+                    EndpointName=self.endpoint_name,
```
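
With the flag in place, streaming can be enabled on the LLM roughly as sketched below. This is a hedged example: the endpoint name, region, and content-handler schema are illustrative and must match the actual model container, and note that the streaming branch parses the `outputs` field directly via `LineIterator` rather than going through `transform_output`:

```python
import json
from typing import Dict

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Illustrative request schema; adapt to your model container.
        return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Used only on the non-streaming path.
        return json.loads(output.read().decode("utf-8"))[0]["generated_text"]


llm = SagemakerEndpoint(
    endpoint_name="my-llm-endpoint",  # hypothetical endpoint name
    region_name="us-east-1",
    content_handler=ContentHandler(),
    streaming=True,  # the flag added by this PR
    # Streaming only kicks in when a run manager exists, so attach a
    # callback handler that prints tokens as they arrive.
    callbacks=[StreamingStdOutCallbackHandler()],
)

print(llm("Tell me a joke"))
```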