mirror of https://github.com/hwchase17/langchain.git
feat: adding paygo api support for Azure ML / Azure AI Studio (#14560)
- **Description:** Introducing support for LLMs and chat models running in Azure AI Studio and Azure ML using the new pay-as-you-go deployment mode (model as a service).
- **Issue:** NA
- **Dependencies:** None.
- **Tag maintainer:** @prakharg-msft @gdyre
- **Twitter handle:** @santiagofacundo

Examples added:

* [docs/docs/integrations/llms/azure_ml.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_endpoint.ipynb)
* [docs/docs/integrations/chat/azureml_chat_endpoint.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb)

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
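As a quick orientation before the diff: a minimal usage sketch of the pay-as-you-go (serverless) chat mode this commit enables. The endpoint URL and API key are placeholders, and `AzureMLEndpointApiType` and `LlamaChatContentFormatter` are the names introduced by the changes below.

```python
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaChatContentFormatter,
)
from langchain_community.llms.azureml_endpoint import AzureMLEndpointApiType
from langchain_core.messages import HumanMessage

# Placeholder endpoint values; serverless (pay-as-you-go) deployments expose an
# OpenAI-style /v1/chat/completions route, which the new URL validator expects.
chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=LlamaChatContentFormatter(),
)
response = chat.invoke([HumanMessage(content="Tell me a joke.")])
```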
libs/community/langchain_community/chat_models/azureml_endpoint.py

@@ -1,8 +1,8 @@
 import json
 from typing import Any, Dict, List, Optional, cast

-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -10,16 +10,24 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.pydantic_v1 import SecretStr, validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.outputs import ChatGeneration, ChatResult

 from langchain_community.llms.azureml_endpoint import (
     AzureMLEndpointClient,
+    AzureMLBaseEndpoint,
+    AzureMLEndpointApiType,
     ContentFormatterBase,
 )


 class LlamaContentFormatter(ContentFormatterBase):
+    def __init__(self):
+        raise TypeError(
+            "`LlamaContentFormatter` is deprecated for chat models. Use "
+            "`LlamaChatContentFormatter` instead."
+        )
+
+
+class LlamaChatContentFormatter(ContentFormatterBase):
     """Content formatter for `LLaMA`."""

     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@@ -45,7 +53,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
+            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -53,79 +61,96 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaContentFormatter.SUPPORTED_ROLES]
+                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
                 Supported roles for the LLaMa Foundation Model: {supported}"""
             )

-    def _format_request_payload(
-        self, messages: List[BaseMessage], model_kwargs: Dict
-    ) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> str:
+        """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaContentFormatter._convert_message_to_dict(message)
+            LlamaChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        prompt = json.dumps(
-            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
-        )
-        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
-
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        """Formats the request according to the chosen api"""
-        return str.encode(prompt)
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> ChatGeneration:
         """Formats response"""
-        return json.loads(output)["output"]
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)["output"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice.strip(),
+                    type="assistant",
+                ),
+                generation_info=None,
+            )
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        "model. Expected `dict` but `{type(choice)}` was received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice["message"]["content"].strip(),
+                    type=choice["message"]["role"],
+                ),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")


-class AzureMLChatOnlineEndpoint(SimpleChatModel):
-    """`AzureML` Chat models API.
+class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint chat models.

     Example:
         .. code-block:: python

-            azure_chat = AzureMLChatOnlineEndpoint(
+            azure_llm = AzureMLOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLApiType.realtime,
                 endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
+                content_formatter=chat_content_formatter,
             )
-    """
-
-    endpoint_url: str = ""
-    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_URL`."""
-
-    endpoint_api_key: SecretStr = convert_to_secret_str("")
-    """Authentication Key for Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_API_KEY`."""
-
-    http_client: Any = None  #: :meta private:
-
-    content_formatter: Any = None
-    """The content formatter that provides an input and output
-    transform function to handle formats between the LLM and
-    the endpoint"""
-
-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model."""
-
-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        values["endpoint_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
-        )
-        endpoint_url = get_from_dict_or_env(
-            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
-        )
-        http_client = AzureMLEndpointClient(
-            endpoint_url, values["endpoint_api_key"].get_secret_value()
-        )
-        return http_client
+    """  # noqa: E501

     @property
     def _identifying_params(self) -> Dict[str, Any]:
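The formatter above is the crux of the change: the same chat messages serialize to two different wire formats depending on `api_type`. A standalone sketch of the two payload shapes (the sample messages and kwargs are assumptions):

```python
import json

chat_messages = [{"role": "user", "content": "Tell me a joke."}]
model_kwargs = {"temperature": 0.8, "max_tokens": 64}

# realtime (hosted infrastructure): Azure ML's input_data envelope
realtime_body = json.dumps(
    {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
)

# serverless (pay-as-you-go): OpenAI-style chat body with kwargs at the top level
serverless_body = json.dumps({"messages": chat_messages, **model_kwargs})
```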
@@ -140,13 +165,13 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
         """Return type of llm."""
         return "azureml_chat_endpoint"

-    def _call(
+    def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> ChatResult:
         """Call out to an AzureML Managed Online endpoint.
         Args:
             messages: The messages in the conversation with the chat model.
@@ -158,12 +183,17 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
+        if stop:
+            _model_kwargs["stop"] = stop

-        request_payload = self.content_formatter._format_request_payload(
-            messages, _model_kwargs
+        request_payload = self.content_formatter.format_request_payload(
+            messages, _model_kwargs, self.endpoint_api_type
         )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
+        response_payload = self.http_client.call(
+            body=request_payload, run_manager=run_manager
         )
-        return generated_text
+        generations = self.content_formatter.format_response_payload(
+            response_payload, self.endpoint_api_type
+        )
+        return ChatResult(generations=[generations])
libs/community/langchain_community/llms/azureml_endpoint.py

@@ -2,12 +2,14 @@ import json
 import urllib.request
 import warnings
 from abc import abstractmethod
+from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional

-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.pydantic_v1 import BaseModel, validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


 class AzureMLEndpointClient(object):
@@ -26,7 +28,12 @@ class AzureMLEndpointClient(object):
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name

-    def call(self, body: bytes, **kwargs: Any) -> bytes:
+    def call(
+        self,
+        body: bytes,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> bytes:
         """call."""

         # The azureml-model-deployment header will force the request to go to a
@@ -45,6 +52,16 @@ class AzureMLEndpointClient(object):
         return result


+class AzureMLEndpointApiType(str, Enum):
+    """Azure ML endpoints API types. Use `realtime` for models deployed in hosted
+    infrastructure, or `serverless` for models deployed as a service with a
+    pay-as-you-go billing or PTU.
+    """
+
+    realtime = "realtime"
+    serverless = "serverless"
+
+
 class ContentFormatterBase:
     """Transform request and response of AzureML endpoint to match with
     required schema.
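Because `AzureMLEndpointApiType` subclasses `str`, the API type can be supplied either as the enum member or as a plain string, which is what lets the `validate_environ` root validator later in this file read it straight from `AZUREML_ENDPOINT_API_TYPE`. A tiny illustration:

```python
from langchain_community.llms.azureml_endpoint import AzureMLEndpointApiType

# str-valued enum members compare equal to their string values
assert AzureMLEndpointApiType.serverless == "serverless"
assert AzureMLEndpointApiType("realtime") is AzureMLEndpointApiType.realtime
```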
@@ -61,7 +78,8 @@ class ContentFormatterBase:
             def format_request_payload(
                 self,
                 prompt: str,
-                model_kwargs: Dict
+                model_kwargs: Dict,
+                api_type: AzureMLEndpointApiType,
             ) -> bytes:
                 input_str = json.dumps(
                     {
@@ -71,7 +89,9 @@ class ContentFormatterBase:
                 )
                 return str.encode(input_str)

-            def format_response_payload(self, output: str) -> str:
+            def format_response_payload(
+                self, output: str, api_type: AzureMLEndpointApiType
+            ) -> str:
                 response_json = json.loads(output)
                 return response_json[0]["0"]
     """
@@ -81,6 +101,12 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""

+    format_error_msg: Optional[str] = (
+        "Error while formatting response payload for chat model of type "
+        " `{api_type}`. Are you using the right formatter for the deployed "
+        " model and endpoint type?"
+    )
+
     @staticmethod
     def escape_special_characters(prompt: str) -> str:
         """Escapes any special characters in `prompt`"""
@@ -100,15 +126,32 @@ class ContentFormatterBase:

         return prompt

+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        """Supported APIs for the given formatter. Azure ML supports
+        deploying models using different hosting methods. Each method may have
+        a different API structure."""
+
+        return [AzureMLEndpointApiType.realtime]
+
     @abstractmethod
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    def format_request_payload(
+        self,
+        prompt: str,
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> bytes:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """

     @abstractmethod
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> Generation:
         """Formats the response body according to the output
         schema of the model. Returns the data type that is
         received from the response.
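For third-party formatters, the new abstract signatures above look like this in practice. A minimal hypothetical formatter for a realtime endpoint that accepts `{"inputs": [...]}` and answers `[{"0": "..."}]` (the class name and schema are illustrative, not from the commit):

```python
import json
from typing import Dict, List

from langchain_core.outputs import Generation

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)


class EchoContentFormatter(ContentFormatterBase):
    """Hypothetical formatter for a realtime endpoint."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.realtime]

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
    ) -> bytes:
        # Escape quotes/newlines before embedding the prompt in JSON
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        return str.encode(
            json.dumps({"inputs": [f'"{prompt}"'], "parameters": model_kwargs})
        )

    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
    ) -> Generation:
        # Formatters now return Generation objects rather than bare strings
        return Generation(text=json.loads(output)[0]["0"])
```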
@@ -118,15 +161,27 @@ class ContentFormatterBase:
 class GPT2ContentFormatter(ContentFormatterBase):
     """Content handler for GPT2"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["0"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class OSSContentFormatter(GPT2ContentFormatter):
@@ -148,21 +203,39 @@ class OSSContentFormatter(GPT2ContentFormatter):
 class HFContentFormatter(ContentFormatterBase):
     """Content handler for LLMs from the HuggingFace catalog."""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["generated_text"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]["generated_text"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class DollyContentFormatter(ContentFormatterBase):
     """Content handler for the Dolly-v2-12b model"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {
@@ -172,49 +245,88 @@ class DollyContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class LlamaContentFormatter(ContentFormatterBase):
     """Content formatter for LLaMa"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
+        """Formats the request according to the chosen api"""
         prompt = ContentFormatterBase.escape_special_characters(prompt)
-        request_payload = json.dumps(
-            {
-                "input_data": {
-                    "input_string": [f'"{prompt}"'],
-                    "parameters": model_kwargs,
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": [f'"{prompt}"'],
+                        "parameters": model_kwargs,
+                    }
                 }
-            }
-        )
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
         """Formats response"""
-        return json.loads(output)[0]["0"]
-
-
-class AzureMLOnlineEndpoint(LLM, BaseModel):
-    """Azure ML Online Endpoint models.
-
-    Example:
-        .. code-block:: python
-
-            azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
-            )
-    """  # noqa: E501
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)[0]["0"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(text=choice)
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        "model. Expected `dict` but `{type(choice)}` was "
+                        "received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(
+                text=choice["text"].strip(),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")


+class AzureMLBaseEndpoint(BaseModel):
+    """Azure ML Online Endpoint models."""
+
     endpoint_url: str = ""
     """URL of pre-existing Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_URL`."""

-    endpoint_api_key: str = ""
+    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
+    """Type of the endpoint being consumed. Possible values are `serverless` for
+        pay-as-you-go and `realtime` for real-time endpoints. """
+
+    endpoint_api_key: SecretStr = convert_to_secret_str("")
     """Authentication Key for Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_API_KEY`."""
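All of the new fields can also be supplied through the environment, which the `validate_environ` root validator in the next hunk reads. An assumed configuration for a pay-as-you-go completions endpoint:

```python
import os

# Assumed values for illustration; validate_environ (below) falls back to these
# variables when the constructor arguments are omitted.
os.environ["AZUREML_ENDPOINT_URL"] = (
    "https://<your-endpoint>.<region>.inference.ml.azure.com/v1/completions"
)
os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key"
os.environ["AZUREML_ENDPOINT_API_TYPE"] = "serverless"
```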
@@ -232,22 +344,106 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
     model_kwargs: Optional[dict] = None
     """Keyword arguments to pass to the model."""

-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exists in environment."""
-        endpoint_key = get_from_dict_or_env(
-            values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
+    @root_validator(pre=True)
+    def validate_environ(cls, values: Dict) -> Dict:
+        values["endpoint_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
         )
-        endpoint_url = get_from_dict_or_env(
+        values["endpoint_url"] = get_from_dict_or_env(
             values, "endpoint_url", "AZUREML_ENDPOINT_URL"
         )
-        deployment_name = get_from_dict_or_env(
+        values["deployment_name"] = get_from_dict_or_env(
             values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
         )
-        http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
+        values["endpoint_api_type"] = get_from_dict_or_env(
+            values,
+            "endpoint_api_type",
+            "AZUREML_ENDPOINT_API_TYPE",
+            AzureMLEndpointApiType.realtime,
+        )
+
+        return values
+
+    @validator("content_formatter")
+    def validate_content_formatter(
+        cls, field_value: Any, values: Dict
+    ) -> ContentFormatterBase:
+        """Validate that content formatter is supported by endpoint type."""
+        endpoint_api_type = values.get("endpoint_api_type")
+        if endpoint_api_type not in field_value.supported_api_types:
+            raise ValueError(
+                f"Content formatter f{type(field_value)} is not supported by this "
+                f"endpoint. Supported types are {field_value.supported_api_types} "
+                f"but endpoint is {endpoint_api_type}."
+            )
+        return field_value
+
+    @validator("endpoint_url")
+    def validate_endpoint_url(cls, field_value: Any) -> str:
+        """Validate that endpoint url is complete."""
+        if field_value.endswith("/"):
+            field_value = field_value[:-1]
+        if field_value.endswith("inference.ml.azure.com"):
+            raise ValueError(
+                "`endpoint_url` should contain the full invocation URL including "
+                "`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
+                "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
+            )
+        return field_value
+
+    @validator("endpoint_api_type")
+    def validate_endpoint_api_type(
+        cls, field_value: Any, values: Dict
+    ) -> AzureMLEndpointApiType:
+        """Validate that endpoint api type is compatible with the URL format."""
+        endpoint_url = values.get("endpoint_url")
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+            "/score"
+        ):
+            raise ValueError(
+                "Endpoints of type `realtime` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
+                " If your endpoint URL ends with `/v1/completions` or"
+                "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
+            )
+        if field_value == AzureMLEndpointApiType.serverless and not (
+            endpoint_url.endswith("/v1/completions")
+            or endpoint_url.endswith("/v1/chat/completions")
+        ):
+            raise ValueError(
+                "Endpoints of type `serverless` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
+                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
+            )
+
+        return field_value
+
+    @validator("http_client", always=True)
+    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
+        """Validate that api key and python package exists in environment."""
+        endpoint_url = values.get("endpoint_url")
+        endpoint_key = values.get("endpoint_api_key")
+        deployment_name = values.get("deployment_name")
+
+        http_client = AzureMLEndpointClient(
+            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+        )
         return http_client


+class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint models.
+
+    Example:
+        .. code-block:: python
+
+            azure_llm = AzureMLOnlineEndpoint(
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_api_key="my-api-key",
+                content_formatter=content_formatter,
+            )
+    """  # noqa: E501
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -262,16 +458,17 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
         """Return type of llm."""
         return "azureml_endpoint"

-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Call out to an AzureML Managed Online endpoint.
+    ) -> LLMResult:
+        """Run the LLM on the given prompts.

         Args:
-            prompt: The prompt to pass into the model.
+            prompts: The prompt to pass into the model.
             stop: Optional list of stop words to use when generating.
         Returns:
             The string generated by the model.
@@ -280,12 +477,21 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
+        if stop:
+            _model_kwargs["stop"] = stop
+        generations = []

-        request_payload = self.content_formatter.format_request_payload(
-            prompt, _model_kwargs
-        )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
-        )
-        return generated_text
+        for prompt in prompts:
+            request_payload = self.content_formatter.format_request_payload(
+                prompt, _model_kwargs, self.endpoint_api_type
+            )
+            response_payload = self.http_client.call(
+                body=request_payload, run_manager=run_manager
+            )
+            generated_text = self.content_formatter.format_response_payload(
+                response_payload, self.endpoint_api_type
+            )
+            generations.append([generated_text])
+
+        return LLMResult(generations=generations)
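Putting the completions side together, a hedged end-to-end sketch of the serverless LLM path (placeholder URL and key; per the diff above, `LlamaContentFormatter` is the only built-in formatter that declares serverless support):

```python
from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    AzureMLOnlineEndpoint,
    LlamaContentFormatter,
)

# Serverless completions endpoints must end in /v1/completions,
# which validate_endpoint_api_type enforces.
llm = AzureMLOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<region>.inference.ml.azure.com/v1/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=LlamaContentFormatter(),
)
output = llm.invoke("Tell me a joke.")
```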
libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py

@@ -5,31 +5,31 @@ from langchain_core.outputs import ChatGeneration, LLMResult

 from langchain_community.chat_models.azureml_endpoint import (
     AzureMLChatOnlineEndpoint,
-    LlamaContentFormatter,
+    LlamaChatContentFormatter,
 )


 def test_llama_call() -> None:
     """Test valid call to Open Source Foundation Model."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="Foo")])
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="Foo")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


-def test_timeout_kwargs() -> None:
+def test_temperature_kwargs() -> None:
     """Test that timeout kwarg works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


 def test_message_history() -> None:
     """Test that multiple messages works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(
-        messages=[
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke(
+        [
             HumanMessage(content="Hello."),
             AIMessage(content="Hello!"),
             HumanMessage(content="How are you doing?"),
@@ -40,7 +40,7 @@ def test_message_history() -> None:


 def test_multiple_messages() -> None:
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
     message = HumanMessage(content="Hi!")
     response = chat.generate([[message], [message]])
libs/community/tests/integration_tests/llms/test_azureml_endpoint.py

@@ -7,6 +7,7 @@ from typing import Dict
 from urllib.request import HTTPError

 import pytest
+from langchain_core.pydantic_v1 import ValidationError

 from langchain_community.llms.azureml_endpoint import (
     AzureMLOnlineEndpoint,
@@ -26,7 +27,7 @@ def test_gpt2_call() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -38,7 +39,7 @@ def test_hf_call() -> None:
         deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
         content_formatter=HFContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -50,7 +51,7 @@ def test_dolly_call() -> None:
         deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
         content_formatter=DollyContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -81,7 +82,7 @@ def test_custom_formatter() -> None:
         deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
         content_formatter=CustomFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -93,7 +94,7 @@ def test_missing_content_formatter() -> None:
         endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_invalid_request_format() -> None:
@@ -123,7 +124,31 @@ def test_invalid_request_format() -> None:
             deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
             content_formatter=CustomContentFormatter(),
         )
-        llm("Foo")
+        llm.invoke("Foo")
+
+
+def test_incorrect_url() -> None:
+    """Testing AzureML Endpoint for an incorrect URL"""
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url="https://endpoint.inference.com",
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")
+
+
+def test_incorrect_api_type() -> None:
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_type="serverless",
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")


 def test_incorrect_key() -> None:
@@ -135,7 +160,7 @@ def test_incorrect_key() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_saving_loading_llm(tmp_path: Path) -> None: