community[patch]: Support Streaming in Azure Machine Learning (#18246)

- [x] **PR title**: "community: Support streaming in Azure ML and few
naming changes"

- [x] **PR message**:
- **Description:** Added support for streaming for azureml_endpoint.
Also, renamed and AzureMLEndpointApiType.realtime to
AzureMLEndpointApiType.dedicated. Also, added new classes
CustomOpenAIChatContentFormatter and CustomOpenAIContentFormatter and
updated the classes LlamaChatContentFormatter and LlamaContentFormatter
to now show a deprecated warning message when instantiated.

---------

Co-authored-by: Sachin Paryani <saparan@microsoft.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Sachin Paryani
2024-03-28 16:38:20 -07:00
committed by GitHub
parent ecb11a4a32
commit 25c9f3d1d1
6 changed files with 285 additions and 76 deletions

View File

@@ -62,12 +62,14 @@ class AzureMLEndpointClient(object):
class AzureMLEndpointApiType(str, Enum):
"""Azure ML endpoints API types. Use `realtime` for models deployed in hosted
infrastructure, or `serverless` for models deployed as a service with a
"""Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
infrastructure (also known as Online Endpoints in Azure Machine Learning),
or `serverless` for models deployed as a service with a
pay-as-you-go billing or PTU.
"""
realtime = "realtime"
dedicated = "dedicated"
realtime = "realtime" #: Deprecated
serverless = "serverless"
@@ -141,13 +143,13 @@ class ContentFormatterBase:
deploying models using different hosting methods. Each method may have
a different API structure."""
return [AzureMLEndpointApiType.realtime]
return [AzureMLEndpointApiType.dedicated]
def format_request_payload(
self,
prompt: str,
model_kwargs: Dict,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
) -> Any:
"""Formats the request body according to the input schema of
the model. Returns bytes or seekable file like object in the
@@ -159,7 +161,7 @@ class ContentFormatterBase:
def format_response_payload(
self,
output: bytes,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
) -> Generation:
"""Formats the response body according to the output
schema of the model. Returns the data type that is
@@ -172,7 +174,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
@property
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
return [AzureMLEndpointApiType.dedicated]
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -214,7 +216,7 @@ class HFContentFormatter(ContentFormatterBase):
@property
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
return [AzureMLEndpointApiType.dedicated]
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -240,7 +242,7 @@ class DollyContentFormatter(ContentFormatterBase):
@property
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
return [AzureMLEndpointApiType.dedicated]
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -264,19 +266,22 @@ class DollyContentFormatter(ContentFormatterBase):
return Generation(text=choice)
class LlamaContentFormatter(ContentFormatterBase):
"""Content formatter for LLaMa"""
class CustomOpenAIContentFormatter(ContentFormatterBase):
"""Content formatter for models that use the OpenAI like API scheme."""
@property
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
"""Formats the request according to the chosen api"""
prompt = ContentFormatterBase.escape_special_characters(prompt)
if api_type == AzureMLEndpointApiType.realtime:
if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
request_payload = json.dumps(
{
"input_data": {
@@ -297,7 +302,10 @@ class LlamaContentFormatter(ContentFormatterBase):
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
"""Formats response"""
if api_type == AzureMLEndpointApiType.realtime:
if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
try:
choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e:
@@ -324,6 +332,22 @@ class LlamaContentFormatter(ContentFormatterBase):
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class LlamaContentFormatter(CustomOpenAIContentFormatter):
"""Deprecated: Kept for backwards compatibility
Content formatter for Llama."""
content_formatter: Any = None
def __init__(self) -> None:
super().__init__()
warnings.warn(
"""`LlamaContentFormatter` will be deprecated in the future.
Please use `CustomOpenAIContentFormatter` instead.
"""
)
class AzureMLBaseEndpoint(BaseModel):
"""Azure ML Online Endpoint models."""
@@ -331,9 +355,9 @@ class AzureMLBaseEndpoint(BaseModel):
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_URL`."""
endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
"""Type of the endpoint being consumed. Possible values are `serverless` for
pay-as-you-go and `realtime` for real-time endpoints. """
pay-as-you-go and `dedicated` for dedicated endpoints. """
endpoint_api_key: SecretStr = convert_to_secret_str("")
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
@@ -348,6 +372,8 @@ class AzureMLBaseEndpoint(BaseModel):
http_client: Any = None #: :meta private:
max_retries: int = 1
content_formatter: Any = None
"""The content formatter that provides an input and output
transform function to handle formats between the LLM and
@@ -371,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
values,
"endpoint_api_type",
"AZUREML_ENDPOINT_API_TYPE",
AzureMLEndpointApiType.realtime,
AzureMLEndpointApiType.dedicated,
)
values["timeout"] = get_from_dict_or_env(
values,
@@ -404,7 +430,7 @@ class AzureMLBaseEndpoint(BaseModel):
if field_value.endswith("inference.ml.azure.com"):
raise ValueError(
"`endpoint_url` should contain the full invocation URL including "
"`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
"`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
"or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
)
return field_value
@@ -415,11 +441,15 @@ class AzureMLBaseEndpoint(BaseModel):
) -> AzureMLEndpointApiType:
"""Validate that endpoint api type is compatible with the URL format."""
endpoint_url = values.get("endpoint_url")
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( # type: ignore[union-attr]
"/score"
if (
(
field_value == AzureMLEndpointApiType.dedicated
or field_value == AzureMLEndpointApiType.realtime
)
and not endpoint_url.endswith("/score") # type: ignore[union-attr]
):
raise ValueError(
"Endpoints of type `realtime` should follow the format "
"Endpoints of type `dedicated` should follow the format "
"`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
" If your endpoint URL ends with `/v1/completions` or"
"`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
@@ -461,7 +491,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
.. code-block:: python
azure_llm = AzureMLOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_type=AzureMLApiType.realtime,
endpoint_api_type=AzureMLApiType.dedicated,
endpoint_api_key="my-api-key",
timeout=120,
content_formatter=content_formatter,