Community: Add mistral oss model support to azureml endpoints, plus configurable timeout (#19123)

- **Description:** There was no formatter for mistral models for Azure
ML endpoints. Adding that, plus a configurable timeout (it was hard
coded before)
- **Dependencies:** none
- **Twitter handle:** @tjaffri @docugami
This commit is contained in:
Taqi Jaffri
2024-03-18 21:10:42 -07:00
committed by GitHub
parent 07de4abe70
commit 044bc22acc
2 changed files with 62 additions and 2 deletions

View File

@@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
DEFAULT_TIMEOUT = 50
class AzureMLEndpointClient(object):
"""AzureML Managed Endpoint client."""
def __init__(
self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = ""
self,
endpoint_url: str,
endpoint_api_key: str,
deployment_name: str = "",
timeout: int = DEFAULT_TIMEOUT,
) -> None:
"""Initialize the class."""
if not endpoint_api_key or not endpoint_url:
@@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
self.endpoint_url = endpoint_url
self.endpoint_api_key = endpoint_api_key
self.deployment_name = deployment_name
self.timeout = timeout
def call(
self,
@@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
headers["azureml-model-deployment"] = self.deployment_name
req = urllib.request.Request(self.endpoint_url, body, headers)
response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50))
response = urllib.request.urlopen(
req, timeout=kwargs.get("timeout", self.timeout)
)
result = response.read()
return result
@@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
"""Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`."""
timeout: int = DEFAULT_TIMEOUT
"""Request timeout for calls to the endpoint"""
http_client: Any = None #: :meta private:
content_formatter: Any = None
@@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
"AZUREML_ENDPOINT_API_TYPE",
AzureMLEndpointApiType.realtime,
)
values["timeout"] = get_from_dict_or_env(
values,
"timeout",
"AZUREML_TIMEOUT",
str(DEFAULT_TIMEOUT),
)
return values
@@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
endpoint_url = values.get("endpoint_url")
endpoint_key = values.get("endpoint_api_key")
deployment_name = values.get("deployment_name")
timeout = values.get("timeout", DEFAULT_TIMEOUT)
http_client = AzureMLEndpointClient(
endpoint_url, # type: ignore
endpoint_key.get_secret_value(), # type: ignore
deployment_name, # type: ignore
timeout, # type: ignore
)
return http_client
@@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_type=AzureMLApiType.realtime,
endpoint_api_key="my-api-key",
timeout=120,
content_formatter=content_formatter,
)
""" # noqa: E501