mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 04:28:58 +00:00
community[patch]: Support Streaming in Azure Machine Learning (#18246)
- [x] **PR title**: "community: Support streaming in Azure ML and few naming changes" - [x] **PR message**: - **Description:** Added support for streaming for azureml_endpoint. Also, renamed AzureMLEndpointApiType.realtime to AzureMLEndpointApiType.dedicated. Also, added new classes CustomOpenAIChatContentFormatter and CustomOpenAIContentFormatter and updated the classes LlamaChatContentFormatter and LlamaContentFormatter to now show a deprecated warning message when instantiated. --------- Co-authored-by: Sachin Paryani <saparan@microsoft.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -62,12 +62,14 @@ class AzureMLEndpointClient(object):
|
||||
|
||||
|
||||
class AzureMLEndpointApiType(str, Enum):
|
||||
"""Azure ML endpoints API types. Use `realtime` for models deployed in hosted
|
||||
infrastructure, or `serverless` for models deployed as a service with a
|
||||
"""Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
|
||||
infrastructure (also known as Online Endpoints in Azure Machine Learning),
|
||||
or `serverless` for models deployed as a service with a
|
||||
pay-as-you-go billing or PTU.
|
||||
"""
|
||||
|
||||
realtime = "realtime"
|
||||
dedicated = "dedicated"
|
||||
realtime = "realtime" #: Deprecated
|
||||
serverless = "serverless"
|
||||
|
||||
|
||||
@@ -141,13 +143,13 @@ class ContentFormatterBase:
|
||||
deploying models using different hosting methods. Each method may have
|
||||
a different API structure."""
|
||||
|
||||
return [AzureMLEndpointApiType.realtime]
|
||||
return [AzureMLEndpointApiType.dedicated]
|
||||
|
||||
def format_request_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
model_kwargs: Dict,
|
||||
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
|
||||
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
|
||||
) -> Any:
|
||||
"""Formats the request body according to the input schema of
|
||||
the model. Returns bytes or seekable file like object in the
|
||||
@@ -159,7 +161,7 @@ class ContentFormatterBase:
|
||||
def format_response_payload(
|
||||
self,
|
||||
output: bytes,
|
||||
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
|
||||
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
|
||||
) -> Generation:
|
||||
"""Formats the response body according to the output
|
||||
schema of the model. Returns the data type that is
|
||||
@@ -172,7 +174,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||
return [AzureMLEndpointApiType.realtime]
|
||||
return [AzureMLEndpointApiType.dedicated]
|
||||
|
||||
def format_request_payload( # type: ignore[override]
|
||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||
@@ -214,7 +216,7 @@ class HFContentFormatter(ContentFormatterBase):
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||
return [AzureMLEndpointApiType.realtime]
|
||||
return [AzureMLEndpointApiType.dedicated]
|
||||
|
||||
def format_request_payload( # type: ignore[override]
|
||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||
@@ -240,7 +242,7 @@ class DollyContentFormatter(ContentFormatterBase):
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||
return [AzureMLEndpointApiType.realtime]
|
||||
return [AzureMLEndpointApiType.dedicated]
|
||||
|
||||
def format_request_payload( # type: ignore[override]
|
||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||
@@ -264,19 +266,22 @@ class DollyContentFormatter(ContentFormatterBase):
|
||||
return Generation(text=choice)
|
||||
|
||||
|
||||
class LlamaContentFormatter(ContentFormatterBase):
|
||||
"""Content formatter for LLaMa"""
|
||||
class CustomOpenAIContentFormatter(ContentFormatterBase):
|
||||
"""Content formatter for models that use the OpenAI like API scheme."""
|
||||
|
||||
@property
|
||||
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
|
||||
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
|
||||
return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
|
||||
|
||||
def format_request_payload( # type: ignore[override]
|
||||
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
|
||||
) -> bytes:
|
||||
"""Formats the request according to the chosen api"""
|
||||
prompt = ContentFormatterBase.escape_special_characters(prompt)
|
||||
if api_type == AzureMLEndpointApiType.realtime:
|
||||
if api_type in [
|
||||
AzureMLEndpointApiType.dedicated,
|
||||
AzureMLEndpointApiType.realtime,
|
||||
]:
|
||||
request_payload = json.dumps(
|
||||
{
|
||||
"input_data": {
|
||||
@@ -297,7 +302,10 @@ class LlamaContentFormatter(ContentFormatterBase):
|
||||
self, output: bytes, api_type: AzureMLEndpointApiType
|
||||
) -> Generation:
|
||||
"""Formats response"""
|
||||
if api_type == AzureMLEndpointApiType.realtime:
|
||||
if api_type in [
|
||||
AzureMLEndpointApiType.dedicated,
|
||||
AzureMLEndpointApiType.realtime,
|
||||
]:
|
||||
try:
|
||||
choice = json.loads(output)[0]["0"]
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
@@ -324,6 +332,22 @@ class LlamaContentFormatter(ContentFormatterBase):
|
||||
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
|
||||
|
||||
|
||||
class LlamaContentFormatter(CustomOpenAIContentFormatter):
|
||||
"""Deprecated: Kept for backwards compatibility
|
||||
|
||||
Content formatter for Llama."""
|
||||
|
||||
content_formatter: Any = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"""`LlamaContentFormatter` will be deprecated in the future.
|
||||
Please use `CustomOpenAIContentFormatter` instead.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class AzureMLBaseEndpoint(BaseModel):
|
||||
"""Azure ML Online Endpoint models."""
|
||||
|
||||
@@ -331,9 +355,9 @@ class AzureMLBaseEndpoint(BaseModel):
|
||||
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as
|
||||
env var `AZUREML_ENDPOINT_URL`."""
|
||||
|
||||
endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
|
||||
endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
|
||||
"""Type of the endpoint being consumed. Possible values are `serverless` for
|
||||
pay-as-you-go and `realtime` for real-time endpoints. """
|
||||
pay-as-you-go and `dedicated` for dedicated endpoints. """
|
||||
|
||||
endpoint_api_key: SecretStr = convert_to_secret_str("")
|
||||
"""Authentication Key for Endpoint. Should be passed to constructor or specified as
|
||||
@@ -348,6 +372,8 @@ class AzureMLBaseEndpoint(BaseModel):
|
||||
|
||||
http_client: Any = None #: :meta private:
|
||||
|
||||
max_retries: int = 1
|
||||
|
||||
content_formatter: Any = None
|
||||
"""The content formatter that provides an input and output
|
||||
transform function to handle formats between the LLM and
|
||||
@@ -371,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
|
||||
values,
|
||||
"endpoint_api_type",
|
||||
"AZUREML_ENDPOINT_API_TYPE",
|
||||
AzureMLEndpointApiType.realtime,
|
||||
AzureMLEndpointApiType.dedicated,
|
||||
)
|
||||
values["timeout"] = get_from_dict_or_env(
|
||||
values,
|
||||
@@ -404,7 +430,7 @@ class AzureMLBaseEndpoint(BaseModel):
|
||||
if field_value.endswith("inference.ml.azure.com"):
|
||||
raise ValueError(
|
||||
"`endpoint_url` should contain the full invocation URL including "
|
||||
"`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
|
||||
"`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
|
||||
"or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
|
||||
)
|
||||
return field_value
|
||||
@@ -415,11 +441,15 @@ class AzureMLBaseEndpoint(BaseModel):
|
||||
) -> AzureMLEndpointApiType:
|
||||
"""Validate that endpoint api type is compatible with the URL format."""
|
||||
endpoint_url = values.get("endpoint_url")
|
||||
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( # type: ignore[union-attr]
|
||||
"/score"
|
||||
if (
|
||||
(
|
||||
field_value == AzureMLEndpointApiType.dedicated
|
||||
or field_value == AzureMLEndpointApiType.realtime
|
||||
)
|
||||
and not endpoint_url.endswith("/score") # type: ignore[union-attr]
|
||||
):
|
||||
raise ValueError(
|
||||
"Endpoints of type `realtime` should follow the format "
|
||||
"Endpoints of type `dedicated` should follow the format "
|
||||
"`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
|
||||
" If your endpoint URL ends with `/v1/completions` or"
|
||||
"`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
|
||||
@@ -461,7 +491,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
|
||||
.. code-block:: python
|
||||
azure_llm = AzureMLOnlineEndpoint(
|
||||
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
|
||||
endpoint_api_type=AzureMLApiType.realtime,
|
||||
endpoint_api_type=AzureMLApiType.dedicated,
|
||||
endpoint_api_key="my-api-key",
|
||||
timeout=120,
|
||||
content_formatter=content_formatter,
|
||||
|
Reference in New Issue
Block a user