feat: adding paygo api support for Azure ML / Azure AI Studio (#14560)

- **Description:** Introduces support for LLMs and chat models running in Azure AI Studio and Azure ML through the new pay-as-you-go deployment mode (model as a service). A minimal usage sketch follows the examples list below.
- **Issue:** NA
- **Dependencies:** None.
- **Tag maintainer:** @prakharg-msft @gdyre 
- **Twitter handle:** @santiagofacundo

Examples added:
* [docs/docs/integrations/llms/azure_ml.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_endpoint.ipynb)
* [docs/docs/integrations/chat/azureml_chat_endpoint.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb)
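
A minimal usage sketch of the new pay-as-you-go (serverless) mode, not taken from the notebooks above; the endpoint URL and API key are placeholders, and it assumes a serverless deployment has already been created in Azure AI Studio:

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaChatContentFormatter,
)
from langchain_community.llms.azureml_endpoint import AzureMLEndpointApiType

# Placeholder endpoint URL and key for a pay-as-you-go (serverless) deployment.
chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your-region>.inference.ai.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=LlamaChatContentFormatter(),
)

# Invoke the chat model as usual; the formatter builds the serverless payload.
response = chat.invoke([HumanMessage(content="Tell me a joke.")])
print(response.content)
```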

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Author: Facundo Santiago
Date: 2024-01-23 22:08:51 -03:00
Committed by: GitHub
Commit: 92e6a641fd (parent 9ce177580a)
6 changed files with 630 additions and 206 deletions


@@ -1,8 +1,8 @@
 import json
 from typing import Any, Dict, List, Optional, cast
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -10,16 +10,24 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.pydantic_v1 import SecretStr, validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.outputs import ChatGeneration, ChatResult
 
 from langchain_community.llms.azureml_endpoint import (
-    AzureMLEndpointClient,
+    AzureMLBaseEndpoint,
+    AzureMLEndpointApiType,
     ContentFormatterBase,
 )
 
 
 class LlamaContentFormatter(ContentFormatterBase):
+    def __init__(self):
+        raise TypeError(
+            "`LlamaContentFormatter` is deprecated for chat models. Use "
+            "`LlamaChatContentFormatter` instead."
+        )
+
+
+class LlamaChatContentFormatter(ContentFormatterBase):
     """Content formatter for `LLaMA`."""
 
     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@@ -45,7 +53,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
+            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -53,79 +61,96 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaContentFormatter.SUPPORTED_ROLES]
+                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
                 Supported roles for the LLaMa Foundation Model: {supported}"""
             )
 
-    def _format_request_payload(
-        self, messages: List[BaseMessage], model_kwargs: Dict
-    ) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> str:
         """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaContentFormatter._convert_message_to_dict(message)
+            LlamaChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        prompt = json.dumps(
-            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
-        )
-        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)
 
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        """Formats the request according to the chosen api"""
-        return str.encode(prompt)
-
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> ChatGeneration:
         """Formats response"""
-        return json.loads(output)["output"]
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)["output"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice.strip(),
+                    type="assistant",
+                ),
+                generation_info=None,
+            )
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        "model. Expected `dict` but `{type(choice)}` was received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice["message"]["content"].strip(),
+                    type=choice["message"]["role"],
+                ),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
-class AzureMLChatOnlineEndpoint(SimpleChatModel):
-    """`AzureML` Chat models API.
+class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint chat models.
 
     Example:
         .. code-block:: python
-            azure_chat = AzureMLChatOnlineEndpoint(
+            azure_llm = AzureMLOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLApiType.realtime,
                 endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
+                content_formatter=chat_content_formatter,
             )
-    """
-
-    endpoint_url: str = ""
-    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_URL`."""
-
-    endpoint_api_key: SecretStr = convert_to_secret_str("")
-    """Authentication Key for Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_API_KEY`."""
-
-    http_client: Any = None  #: :meta private:
-
-    content_formatter: Any = None
-    """The content formatter that provides an input and output
-    transform function to handle formats between the LLM and
-    the endpoint"""
-
-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model."""
-
-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        values["endpoint_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
-        )
-        endpoint_url = get_from_dict_or_env(
-            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
-        )
-        http_client = AzureMLEndpointClient(
-            endpoint_url, values["endpoint_api_key"].get_secret_value()
-        )
-        return http_client
+    """  # noqa: E501
 
     @property
     def _identifying_params(self) -> Dict[str, Any]:
@@ -140,13 +165,13 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
         """Return type of llm."""
         return "azureml_chat_endpoint"
 
-    def _call(
+    def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> ChatResult:
         """Call out to an AzureML Managed Online endpoint.
         Args:
             messages: The messages in the conversation with the chat model.
@@ -158,12 +183,17 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
+        if stop:
+            _model_kwargs["stop"] = stop
 
-        request_payload = self.content_formatter._format_request_payload(
-            messages, _model_kwargs
+        request_payload = self.content_formatter.format_request_payload(
+            messages, _model_kwargs, self.endpoint_api_type
         )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
+        response_payload = self.http_client.call(
+            body=request_payload, run_manager=run_manager
         )
-        return generated_text
+        generations = self.content_formatter.format_response_payload(
+            response_payload, self.endpoint_api_type
+        )
+        return ChatResult(generations=[generations])
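
To make the formatter change above concrete, here is a small self-contained sketch (not part of the commit) of the two request bodies that `LlamaChatContentFormatter.format_request_payload` builds, mirroring the realtime and serverless branches in the diff:

```python
import json

# Chat history already converted to role/content dicts by the formatter.
chat_messages = [{"role": "user", "content": "Tell me a joke."}]
model_kwargs = {"temperature": 0.8, "max_new_tokens": 128}

# Realtime (dedicated managed online endpoint): Azure ML "input_data" envelope.
realtime_body = json.dumps(
    {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
)

# Serverless (pay-as-you-go, model as a service): OpenAI-style chat payload.
serverless_body = json.dumps({"messages": chat_messages, **model_kwargs})

print(realtime_body)
print(serverless_body)
```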