Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-08 06:00:41 +00:00)
Added Streaming Capability to SageMaker LLMs (#10535)
This PR adds the ability to request a streaming response from the SageMaker LLM by leveraging the `invoke_endpoint_with_response_stream` capability in `boto3`. It is heavily based on the AWS blog post announcing the feature, linked [here](https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/). It adds no new dependencies, since it uses the existing `boto3` version.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
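Before the diff, a quick orientation: the sketch below shows how the new flag might be used from the caller's side. It is a minimal sketch, not part of the PR — the endpoint name, region, and the `{"inputs": ..., "outputs": [...]}` payload shape are assumptions that mirror the line format documented in `LineIterator` below.

    import json

    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint


    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            # Assumed payload shape; match it to your model container.
            return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

        def transform_output(self, output: bytes) -> str:
            return json.loads(output.read().decode("utf-8"))["outputs"][0]


    llm = SagemakerEndpoint(
        endpoint_name="my-streaming-endpoint",  # hypothetical endpoint name
        region_name="us-east-1",                # hypothetical region
        content_handler=ContentHandler(),
        streaming=True,  # the flag added by this PR
        # Streaming only kicks in when a run_manager is present, i.e. when
        # callbacks are attached; each chunk arrives via on_llm_new_token.
        callbacks=[StreamingStdOutCallbackHandler()],
    )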
This commit is contained in: parent d9670a5945, commit 4236ae3851
@@ -1,6 +1,8 @@
 """Sagemaker InvokeEndpoint API."""
+import io
 import json
 from abc import abstractmethod
-from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -8,7 +10,66 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import Extra, root_validator
 
 INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
-OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]]])
+OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
+
+
+class LineIterator:
+    """
+    A helper class for parsing the byte stream input.
+
+    The output of the model will be in the following format:
+
+    b'{"outputs": [" a"]}\n'
+    b'{"outputs": [" challenging"]}\n'
+    b'{"outputs": [" problem"]}\n'
+    ...
+
+    While usually each PayloadPart event from the event stream will
+    contain a byte array with a full JSON object, this is not guaranteed
+    and some of the JSON objects may be split across PayloadPart events.
+
+    For example:
+
+    {'PayloadPart': {'Bytes': b'{"outputs": '}}
+    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+
+    This class accounts for this by concatenating the bytes of successive
+    PayloadPart events in an internal buffer and only returning complete
+    lines (ending with a '\n' character) from '__next__'.
+    It tracks the position of the last read so that
+    previously returned bytes are not exposed again.
+
+    For more details see:
+    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
+    """
+
+    def __init__(self, stream: Any) -> None:
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self) -> "LineIterator":
+        return self
+
+    def __next__(self) -> Any:
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                # Unknown event type; skip it
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
 class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
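To make the buffering behavior concrete, here is a small illustrative sketch that drives `LineIterator` with fake event-stream chunks shaped like the docstring's examples. It assumes `LineIterator` is importable from this module after the commit; nothing here is part of the diff itself.

    import json

    from langchain.llms.sagemaker_endpoint import LineIterator

    # One complete line, one unknown event (skipped), and one JSON object
    # split across two PayloadPart events that the iterator must reassemble.
    fake_stream = [
        {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
        {"SomeOtherEvent": {}},
        {"PayloadPart": {"Bytes": b'{"outputs": '}},
        {"PayloadPart": {"Bytes": b'[" challenging"]}\n'}},
    ]

    for line in LineIterator(fake_stream):
        print(json.loads(line)["outputs"][0])
    # prints " a" and then " challenging"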
@@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
     and the endpoint.
     """
 
+    streaming: bool = False
+    """Whether to stream the results."""
+
     """
     Example:
         .. code-block:: python
@@ -264,22 +328,43 @@ class SagemakerEndpoint(LLM):
         content_type = self.content_handler.content_type
         accepts = self.content_handler.accepts
 
-        # send request
-        try:
-            response = self.client.invoke_endpoint(
-                EndpointName=self.endpoint_name,
-                Body=body,
-                ContentType=content_type,
-                Accept=accepts,
-                **_endpoint_kwargs,
-            )
-        except Exception as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
+        if self.streaming and run_manager:
+            try:
+                resp = self.client.invoke_endpoint_with_response_stream(
+                    EndpointName=self.endpoint_name,
+                    Body=body,
+                    ContentType=self.content_handler.content_type,
+                    **_endpoint_kwargs,
+                )
+                iterator = LineIterator(resp["Body"])
+                current_completion: str = ""
+                for line in iterator:
+                    resp = json.loads(line)
+                    resp_output = resp.get("outputs")[0]
+                    if stop is not None:
+                        # Uses same approach as below
+                        resp_output = enforce_stop_tokens(resp_output, stop)
+                    current_completion += resp_output
+                    run_manager.on_llm_new_token(resp_output)
+                return current_completion
+            except Exception as e:
+                raise ValueError(f"Error raised by streaming inference endpoint: {e}")
+        else:
+            try:
+                response = self.client.invoke_endpoint(
+                    EndpointName=self.endpoint_name,
+                    Body=body,
+                    ContentType=content_type,
+                    Accept=accepts,
+                    **_endpoint_kwargs,
+                )
+            except Exception as e:
+                raise ValueError(f"Error raised by inference endpoint: {e}")
 
-        text = self.content_handler.transform_output(response["Body"])
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to the sagemaker endpoint.
-            text = enforce_stop_tokens(text, stop)
+            text = self.content_handler.transform_output(response["Body"])
+            if stop is not None:
+                # This is a bit hacky, but I can't figure out a better way to enforce
+                # stop tokens when making calls to the sagemaker endpoint.
+                text = enforce_stop_tokens(text, stop)
 
-        return text
+            return text
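Since the new branch only runs when `self.streaming` is set and a `run_manager` is present, streamed chunks surface through callbacks via `on_llm_new_token`. Below is a hedged sketch of collecting them with a custom handler; the `TokenCollector` name is made up for illustration, and only the standard `BaseCallbackHandler` API is assumed.

    from typing import Any, List

    from langchain.callbacks.base import BaseCallbackHandler


    class TokenCollector(BaseCallbackHandler):
        """Accumulates each streamed chunk in arrival order."""

        def __init__(self) -> None:
            self.tokens: List[str] = []

        def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            self.tokens.append(token)


    collector = TokenCollector()
    # llm = SagemakerEndpoint(..., streaming=True, callbacks=[collector])
    # completion = llm("What makes streaming hard?")
    # collector.tokens now holds the individual chunks; their concatenation
    # equals the returned completion (after stop-token enforcement).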