mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 22:15:08 +00:00
Added Streaming Capability to SageMaker LLMs (#10535)
This PR adds the ability to declare a Streaming response in the SageMaker LLM by leveraging the `invoke_endpoint_with_response_stream` capability in `boto3`. It is heavily based on the AWS Blog Post announcement linked [here](https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/). It does not add any additional dependencies since it uses the existing `boto3` version. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d9670a5945
commit
4236ae3851
@ -1,6 +1,8 @@
|
|||||||
"""Sagemaker InvokeEndpoint API."""
|
"""Sagemaker InvokeEndpoint API."""
|
||||||
|
import io
|
||||||
|
import json
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union
|
from typing import Any, Dict, Generic, Iterator, List, Mapping, Optional, TypeVar, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
@ -8,7 +10,66 @@ from langchain.llms.utils import enforce_stop_tokens
|
|||||||
from langchain.pydantic_v1 import Extra, root_validator
|
from langchain.pydantic_v1 import Extra, root_validator
|
||||||
|
|
||||||
# Type variables that tie a ContentHandler's generic parameters to the
# payload shapes the SageMaker wrappers support: plain prompts / prompt
# batches on the way in; text, embedding matrices, or (for streaming
# responses) an iterator of chunks on the way out.
INPUT_TYPE = TypeVar("INPUT_TYPE", bound=Union[str, List[str]])
OUTPUT_TYPE = TypeVar("OUTPUT_TYPE", bound=Union[str, List[List[float]], Iterator])
class LineIterator:
    """
    A helper class for parsing the byte stream input.

    The output of the model will be in the following format:

    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...

    While usually each PayloadPart event from the event stream will
    contain a byte array with a full json, this is not guaranteed
    and some of the json objects may be split across PayloadPart events.

    For example:

    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}

    This class accounts for this by buffering incoming bytes and only
    yielding complete lines (ending with a '\n' character). It maintains
    the position of the last read to ensure that previous bytes are not
    exposed again.

    For more details see:
    https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/
    """

    def __init__(self, stream: Any) -> None:
        """Wrap *stream*, an iterable of SageMaker event-stream events.

        Args:
            stream: Any iterable yielding event dicts; events carrying
                bytes are expected under ``{"PayloadPart": {"Bytes": ...}}``.
        """
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        # Offset of the first byte in `buffer` not yet returned to the caller.
        self.read_pos = 0

    def __iter__(self) -> "LineIterator":
        return self

    def __next__(self) -> Any:
        """Return the next complete line (without its trailing ``\\n``).

        Raises:
            StopIteration: when the underlying stream is exhausted and the
                buffer holds no more data.
        """
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                # A full newline-terminated line is buffered; consume it.
                self.read_pos += len(line)
                return line[:-1]
            # No complete line yet - pull more bytes from the event stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # Stream ended with a partial line (no trailing '\n').
                    # Flush it once rather than retrying: no new bytes can
                    # ever arrive, so looping here would never terminate.
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # Unknown event type (e.g. stream metadata) - skip it.
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
|
class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]):
|
||||||
@ -151,6 +212,9 @@ class SagemakerEndpoint(LLM):
|
|||||||
and the endpoint.
|
and the endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
streaming: bool = False
|
||||||
|
"""Whether to stream the results."""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -264,7 +328,28 @@ class SagemakerEndpoint(LLM):
|
|||||||
content_type = self.content_handler.content_type
|
content_type = self.content_handler.content_type
|
||||||
accepts = self.content_handler.accepts
|
accepts = self.content_handler.accepts
|
||||||
|
|
||||||
# send request
|
if self.streaming and run_manager:
|
||||||
|
try:
|
||||||
|
resp = self.client.invoke_endpoint_with_response_stream(
|
||||||
|
EndpointName=self.endpoint_name,
|
||||||
|
Body=body,
|
||||||
|
ContentType=self.content_handler.content_type,
|
||||||
|
**_endpoint_kwargs,
|
||||||
|
)
|
||||||
|
iterator = LineIterator(resp["Body"])
|
||||||
|
current_completion: str = ""
|
||||||
|
for line in iterator:
|
||||||
|
resp = json.loads(line)
|
||||||
|
resp_output = resp.get("outputs")[0]
|
||||||
|
if stop is not None:
|
||||||
|
# Uses same approach as below
|
||||||
|
resp_output = enforce_stop_tokens(resp_output, stop)
|
||||||
|
current_completion += resp_output
|
||||||
|
run_manager.on_llm_new_token(resp_output)
|
||||||
|
return current_completion
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error raised by streaming inference endpoint: {e}")
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
response = self.client.invoke_endpoint(
|
response = self.client.invoke_endpoint(
|
||||||
EndpointName=self.endpoint_name,
|
EndpointName=self.endpoint_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user