feat: adding paygo api support for Azure ML / Azure AI Studio (#14560)

- **Description:** Introducing support for LLMs and Chat models running
in Azure AI studio and Azure ML using the new deployment mode
pay-as-you-go (model as a service).
- **Issue:** NA
- **Dependencies:** None.
- **Tag maintainer:** @prakharg-msft @gdyre 
- **Twitter handle:** @santiagofacundo

Examples added:
* [docs/docs/integrations/llms/azure_ml.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_endpoint.ipynb)
* [docs/docs/integrations/chat/azureml_chat_endpoint.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb)
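
For a quick sense of what the new mode enables, here is a minimal sketch distilled from the notebook examples in this PR (the endpoint URL and API key are placeholders):

```python
from langchain.schema import HumanMessage
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    LlamaChatContentFormatter,
)

# Pay-as-you-go (model as a service) deployments expose an OpenAI-style
# /v1/chat/completions route and are selected with `serverless`; the
# previous hosted-infrastructure behavior is kept under `realtime`.
chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=LlamaChatContentFormatter(),
)
response = chat.invoke(
    [HumanMessage(content="Will the Collatz conjecture ever be solved?")]
)
```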

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Facundo Santiago authored on 2024-01-23 22:08:51 -03:00, committed by GitHub
parent 9ce177580a
commit 92e6a641fd
6 changed files with 630 additions and 206 deletions

File: docs/docs/integrations/chat/azureml_chat_endpoint.ipynb

@@ -15,9 +15,9 @@
     "source": [
      "# AzureMLChatOnlineEndpoint\n",
      "\n",
-     ">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. `Azure Foundation Models` include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
+     ">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides foundational and general purpose models from different providers.\n",
      ">\n",
-     ">[Azure Machine Learning Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints). After you train machine learning models or pipelines, you need to deploy them to production so that others can use them for inference. Inference is the process of applying new input data to the machine learning model or pipeline to generate outputs. While these outputs are typically referred to as \"predictions,\" inferencing can be used to generate outputs for other machine learning tasks, such as classification and clustering. In `Azure Machine Learning`, you perform inferencing by using endpoints and deployments. `Endpoints` and `Deployments` allow you to decouple the interface of your production workload from the implementation that serves it.\n",
+     ">In general, you need to deploy models in order to consume their predictions (inference). In `Azure Machine Learning`, [Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints) are used to deploy these models with real-time serving. They are based on the ideas of `Endpoints` and `Deployments`, which allow you to decouple the interface of your production workload from the implementation that serves it.\n",
      "\n",
      "This notebook goes over how to use a chat model hosted on an `Azure Machine Learning Endpoint`."
     ]
@@ -37,10 +37,11 @@
     "source": [
      "## Set up\n",
      "\n",
-     "To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
+     "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
      "\n",
-     "* `endpoint_api_key`: The API key provided by the endpoint\n",
-     "* `endpoint_url`: The REST endpoint url provided by the endpoint"
+     "* `endpoint_url`: The REST endpoint URL provided by the endpoint.\n",
+     "* `endpoint_api_type`: Use `endpoint_api_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_api_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+     "* `endpoint_api_key`: The API key provided by the endpoint."
     ]
    },
    {
@@ -51,7 +52,40 @@
     "source": [
      "\n",
      "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
      "\n",
-     "* `LLamaContentFormatter`: Formats request and response data for LLaMa2-chat"
+     "* `LlamaChatContentFormatter`: Formats request and response data for LLaMa2-chat\n",
+     "\n",
+     "*Note: `langchain.chat_models.azureml_endpoint.LlamaContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.LlamaChatContentFormatter`.*\n",
+     "\n",
+     "You can implement custom content formatters specific to your model by deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`."
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "## Examples\n",
+     "\n",
+     "The following section contains examples of how to use this class:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from langchain.schema import HumanMessage\n",
+     "from langchain_community.chat_models.azureml_endpoint import (\n",
+     "    AzureMLChatOnlineEndpoint,\n",
+     "    AzureMLEndpointApiType,\n",
+     "    LlamaChatContentFormatter,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Example: Chat completions with real-time endpoints"
     ]
    },
    {
@@ -76,11 +110,79 @@
     "source": [
      "\n",
      "chat = AzureMLChatOnlineEndpoint(\n",
      "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
+     "    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
      "    endpoint_api_key=\"my-api-key\",\n",
-     "    content_formatter=LlamaContentFormatter,\n",
+     "    content_formatter=LlamaChatContentFormatter(),\n",
      ")\n",
-     "response = chat(\n",
-     "    messages=[HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+     "response = chat.invoke(\n",
+     "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+     ")\n",
+     "response"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Example: Chat completions with pay-as-you-go deployments (model as a service)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "chat = AzureMLChatOnlineEndpoint(\n",
+     "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
+     "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+     "    endpoint_api_key=\"my-api-key\",\n",
+     "    content_formatter=LlamaChatContentFormatter(),\n",
+     ")\n",
+     "response = chat.invoke(\n",
+     "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+     ")\n",
+     "response"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "If you need to pass additional parameters to the model, use the `model_kwargs` argument:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "chat = AzureMLChatOnlineEndpoint(\n",
+     "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
+     "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+     "    endpoint_api_key=\"my-api-key\",\n",
+     "    content_formatter=LlamaChatContentFormatter(),\n",
+     "    model_kwargs={\"temperature\": 0.8},\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "Parameters can also be passed during invocation:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "response = chat.invoke(\n",
+     "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")],\n",
+     "    max_tokens=512,\n",
      ")\n",
      "response"
     ]
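
The set-up cell above says custom content formatters can be derived from `ContentFormatterBase`, but the notebook only ships `LlamaChatContentFormatter`. A minimal sketch of what a custom chat formatter could look like, assuming a realtime endpoint that accepts the `input_data`/`input_string` schema shown above and answers with `{"output": ...}` (the class name and response schema are illustrative, not part of this PR):

```python
import json
from typing import Dict, List

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration


class MyChatContentFormatter(ContentFormatterBase):  # hypothetical example class
    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.realtime]

    def format_request_payload(
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType,
    ) -> bytes:
        # Assumes the deployed model accepts the `input_data`/`input_string`
        # request schema used by the realtime examples above; the `m.type`
        # role mapping is a simplification.
        chat_messages = [{"role": m.type, "content": m.content} for m in messages]
        return str.encode(
            json.dumps(
                {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
            )
        )

    def format_response_payload(
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> ChatGeneration:
        # Assumes the endpoint answers with {"output": "<generated text>"}.
        return ChatGeneration(
            message=AIMessage(content=json.loads(output)["output"].strip())
        )
```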

File: docs/docs/integrations/llms/azure_ml.ipynb

@@ -6,9 +6,9 @@
     "source": [
      "# Azure ML\n",
      "\n",
-     "[Azure ML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
+     "[Azure ML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides foundational and general purpose models from different providers.\n",
      "\n",
-     "This notebook goes over how to use an LLM hosted on an `AzureML online endpoint`"
+     "This notebook goes over how to use an LLM hosted on an `Azure ML Online Endpoint`."
     ]
    },
    {
@@ -26,11 +26,12 @@
     "source": [
      "## Set up\n",
      "\n",
-     "To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
+     "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
      "\n",
-     "* `endpoint_api_key`: Required - The API key provided by the endpoint\n",
-     "* `endpoint_url`: Required - The REST endpoint url provided by the endpoint\n",
-     "* `deployment_name`: Not required - The deployment name of the model using the endpoint"
+     "* `endpoint_url`: The REST endpoint URL provided by the endpoint.\n",
+     "* `endpoint_api_type`: Use `endpoint_api_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_api_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+     "* `endpoint_api_key`: The API key provided by the endpoint.\n",
+     "* `deployment_name`: (Optional) The deployment name of the model using the endpoint."
     ]
    },
    {
@@ -46,31 +47,107 @@
     "source": [
      "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
      "* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
      "\n",
-     "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*\n",
-     "\n",
-     "Below is an example using a summarization model from Hugging Face."
+     "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "### Custom Content Formatter"
+     "## Examples"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Example: LLaMa 2 completions with real-time endpoints"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "HaSeul won her first music show trophy with \"So What\" on Mnet's M Countdown. Loona released their second EP titled [#] (read as hash] on February 5, 2020. HaSeul did not take part in the promotion of the album because of mental health issues. On October 19, 2020, they released their third EP called [12:00]. It was their first album to enter the Billboard 200, debuting at number 112. On June 2, 2021, the group released their fourth EP called Yummy-Yummy. On August 27, it was announced that they are making their Japanese debut on September 15 under Universal Music Japan sublabel EMI Records.\n"
-      ]
-     }
-    ],
+    "outputs": [],
+    "source": [
+     "from langchain.schema import HumanMessage\n",
+     "from langchain_community.llms.azureml_endpoint import (\n",
+     "    AzureMLEndpointApiType,\n",
+     "    AzureMLOnlineEndpoint,\n",
+     "    LlamaContentFormatter,\n",
+     ")\n",
+     "\n",
+     "llm = AzureMLOnlineEndpoint(\n",
+     "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
+     "    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+     "    endpoint_api_key=\"my-api-key\",\n",
+     "    content_formatter=LlamaContentFormatter(),\n",
+     "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
+     ")\n",
+     "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
+     "response"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "Model parameters can also be indicated during invocation:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "response = llm.invoke(\"Write me a song about sparkling water:\", temperature=0.5)\n",
+     "response"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Example: Completions with pay-as-you-go deployments (model as a service)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from langchain.schema import HumanMessage\n",
+     "from langchain_community.llms.azureml_endpoint import (\n",
+     "    AzureMLEndpointApiType,\n",
+     "    AzureMLOnlineEndpoint,\n",
+     "    LlamaContentFormatter,\n",
+     ")\n",
+     "\n",
+     "llm = AzureMLOnlineEndpoint(\n",
+     "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions\",\n",
+     "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+     "    endpoint_api_key=\"my-api-key\",\n",
+     "    content_formatter=LlamaContentFormatter(),\n",
+     "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
+     ")\n",
+     "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
+     "response"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "### Example: Custom content formatter\n",
+     "\n",
+     "Below is an example using a summarization model from Hugging Face."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
     "source": [
      "import json\n",
      "import os\n",
@@ -104,6 +181,7 @@
      "content_formatter = CustomFormatter()\n",
      "\n",
      "llm = AzureMLOnlineEndpoint(\n",
+     "    endpoint_api_type=\"realtime\",\n",
      "    endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
      "    endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
      "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
@@ -132,7 +210,7 @@
      "that Loona will release the double A-side single, \"Hula Hoop / Star Seed\" on September 15, with a physical CD release on October \n",
      "20.[53] In December, Chuu filed an injunction to suspend her exclusive contract with Blockberry Creative.[54][55]\n",
      "\"\"\"\n",
-     "summarized_text = llm(large_text)\n",
+     "summarized_text = llm.invoke(large_text)\n",
      "print(summarized_text)"
     ]
    },
@@ -140,22 +218,14 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "### Dolly with LLMChain"
+     "### Example: Dolly with LLMChain"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "Many people are willing to talk about themselves; it's others who seem to be stuck up. Try to understand others where they're coming from. Like minded people can build a tribe together.\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "from langchain.chains import LLMChain\n",
      "from langchain.prompts import PromptTemplate\n",
@@ -177,31 +247,22 @@
      ")\n",
      "\n",
      "chain = LLMChain(llm=llm, prompt=prompt)\n",
-     "print(chain.run({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
+     "print(chain.invoke({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "### Serializing an LLM\n",
+     "## Serializing an LLM\n",
      "You can also save and load LLM configurations"
     ]
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "\u001b[1mAzureMLOnlineEndpoint\u001b[0m\n",
-       "Params: {'deployment_name': 'databricks-dolly-v2-12b-4', 'model_kwargs': {'temperature': 0.2, 'max_tokens': 150, 'top_p': 0.8, 'frequency_penalty': 0.32, 'presence_penalty': 0.072}}\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
      "from langchain_community.llms.loading import load_llm\n",
      "\n",
@@ -224,9 +285,9 @@
    ],
    "metadata": {
     "kernelspec": {
-     "display_name": "Python 3 (ipykernel)",
+     "display_name": "langchain",
      "language": "python",
-     "name": "python3"
+     "name": "langchain"
    },
    "language_info": {
     "codemirror_mode": {
@@ -238,7 +299,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.11.5"
   }
  },
 "nbformat": 4,

File: libs/community/langchain_community/chat_models/azureml_endpoint.py

@@ -1,8 +1,8 @@
 import json
 from typing import Any, Dict, List, Optional, cast

-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -10,16 +10,24 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.pydantic_v1 import SecretStr, validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.outputs import ChatGeneration, ChatResult

 from langchain_community.llms.azureml_endpoint import (
-    AzureMLEndpointClient,
+    AzureMLBaseEndpoint,
+    AzureMLEndpointApiType,
     ContentFormatterBase,
 )


 class LlamaContentFormatter(ContentFormatterBase):
+    def __init__(self):
+        raise TypeError(
+            "`LlamaContentFormatter` is deprecated for chat models. Use "
+            "`LlamaChatContentFormatter` instead."
+        )
+
+
+class LlamaChatContentFormatter(ContentFormatterBase):
     """Content formatter for `LLaMA`."""

     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@@ -45,7 +53,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
+            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -53,79 +61,96 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaContentFormatter.SUPPORTED_ROLES]
+                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
                 Supported roles for the LLaMa Foundation Model: {supported}"""
             )

-    def _format_request_payload(
-        self, messages: List[BaseMessage], model_kwargs: Dict
-    ) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> str:
+        """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaContentFormatter._convert_message_to_dict(message)
+            LlamaChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        prompt = json.dumps(
-            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
-        )
-        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
-
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        """Formats the request according to the chosen api"""
-        return str.encode(prompt)
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> ChatGeneration:
         """Formats response"""
-        return json.loads(output)["output"]
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)["output"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice.strip(),
+                    type="assistant",
+                ),
+                generation_info=None,
+            )
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        f"model. Expected `dict` but `{type(choice)}` was received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice["message"]["content"].strip(),
+                    type=choice["message"]["role"],
+                ),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")


-class AzureMLChatOnlineEndpoint(SimpleChatModel):
-    """`AzureML` Chat models API.
+class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint chat models.

     Example:
         .. code-block:: python

-            azure_llm = AzureMLOnlineEndpoint(
+            azure_chat = AzureMLChatOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLEndpointApiType.realtime,
                 endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
+                content_formatter=chat_content_formatter,
             )
-    """
-
-    endpoint_url: str = ""
-    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_URL`."""
-
-    endpoint_api_key: SecretStr = convert_to_secret_str("")
-    """Authentication Key for Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_API_KEY`."""
-
-    http_client: Any = None  #: :meta private:
-
-    content_formatter: Any = None
-    """The content formatter that provides an input and output
-    transform function to handle formats between the LLM and
-    the endpoint"""
-
-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model."""
-
-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        values["endpoint_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
-        )
-        endpoint_url = get_from_dict_or_env(
-            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
-        )
-        http_client = AzureMLEndpointClient(
-            endpoint_url, values["endpoint_api_key"].get_secret_value()
-        )
-        return http_client
+    """  # noqa: E501

     @property
     def _identifying_params(self) -> Dict[str, Any]:
@@ -140,13 +165,13 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
         """Return type of llm."""
         return "azureml_chat_endpoint"

-    def _call(
+    def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> ChatResult:
         """Call out to an AzureML Managed Online endpoint.
         Args:
             messages: The messages in the conversation with the chat model.
@@ -158,12 +183,17 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
+        if stop:
+            _model_kwargs["stop"] = stop

-        request_payload = self.content_formatter._format_request_payload(
-            messages, _model_kwargs
+        request_payload = self.content_formatter.format_request_payload(
+            messages, _model_kwargs, self.endpoint_api_type
         )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
+        response_payload = self.http_client.call(
+            body=request_payload, run_manager=run_manager
         )
-        return generated_text
+        generations = self.content_formatter.format_response_payload(
+            response_payload, self.endpoint_api_type
+        )
+        return ChatResult(generations=[generations])

File: libs/community/langchain_community/llms/azureml_endpoint.py

@@ -2,12 +2,14 @@ import json
 import urllib.request
 import warnings
 from abc import abstractmethod
+from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional

-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.pydantic_v1 import BaseModel, validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


 class AzureMLEndpointClient(object):
@@ -26,7 +28,12 @@ class AzureMLEndpointClient(object):
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name

-    def call(self, body: bytes, **kwargs: Any) -> bytes:
+    def call(
+        self,
+        body: bytes,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> bytes:
         """call."""

         # The azureml-model-deployment header will force the request to go to a
@@ -45,6 +52,16 @@ class AzureMLEndpointClient(object):
         return result


+class AzureMLEndpointApiType(str, Enum):
+    """Azure ML endpoints API types. Use `realtime` for models deployed in hosted
+    infrastructure, or `serverless` for models deployed as a service with a
+    pay-as-you-go billing or PTU.
+    """
+
+    realtime = "realtime"
+    serverless = "serverless"
+
+
 class ContentFormatterBase:
     """Transform request and response of AzureML endpoint to match with
     required schema.
@@ -61,7 +78,8 @@ class ContentFormatterBase:
         def format_request_payload(
             self,
             prompt: str,
-            model_kwargs: Dict
+            model_kwargs: Dict,
+            api_type: AzureMLEndpointApiType,
         ) -> bytes:
             input_str = json.dumps(
                 {
@@ -71,7 +89,9 @@ class ContentFormatterBase:
             )
             return str.encode(input_str)

-        def format_response_payload(self, output: str) -> str:
+        def format_response_payload(
+            self, output: str, api_type: AzureMLEndpointApiType
+        ) -> str:
             response_json = json.loads(output)
             return response_json[0]["0"]
     """
@@ -81,6 +101,12 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""

+    format_error_msg: Optional[str] = (
+        "Error while formatting response payload for chat model of type "
+        "`{api_type}`. Are you using the right formatter for the deployed "
+        "model and endpoint type?"
+    )
+
     @staticmethod
     def escape_special_characters(prompt: str) -> str:
         """Escapes any special characters in `prompt`"""
@@ -100,15 +126,32 @@ class ContentFormatterBase:
         return prompt

+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        """Supported APIs for the given formatter. Azure ML supports
+        deploying models using different hosting methods. Each method may have
+        a different API structure."""
+        return [AzureMLEndpointApiType.realtime]
+
     @abstractmethod
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    def format_request_payload(
+        self,
+        prompt: str,
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> bytes:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """

     @abstractmethod
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> Generation:
         """Formats the response body according to the output
         schema of the model. Returns the data type that is
         received from the response.
@@ -118,15 +161,27 @@ class ContentFormatterBase:
 class GPT2ContentFormatter(ContentFormatterBase):
     """Content handler for GPT2"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["0"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)
@@ -148,21 +203,39 @@ class OSSContentFormatter(GPT2ContentFormatter):
 class HFContentFormatter(ContentFormatterBase):
     """Content handler for LLMs from the HuggingFace catalog."""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["generated_text"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]["generated_text"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class DollyContentFormatter(ContentFormatterBase):
     """Content handler for the Dolly-v2-12b model"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {
@@ -172,49 +245,88 @@ class DollyContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class LlamaContentFormatter(ContentFormatterBase):
     """Content formatter for LLaMa"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         """Formats the request according to the chosen api"""
         prompt = ContentFormatterBase.escape_special_characters(prompt)
-        request_payload = json.dumps(
-            {
-                "input_data": {
-                    "input_string": [f'"{prompt}"'],
-                    "parameters": model_kwargs,
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": [f'"{prompt}"'],
+                        "parameters": model_kwargs,
+                    }
                 }
-            }
-        )
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
         """Formats response"""
-        return json.loads(output)[0]["0"]
-
-
-class AzureMLOnlineEndpoint(LLM, BaseModel):
-    """Azure ML Online Endpoint models.
-
-    Example:
-        .. code-block:: python
-
-            azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
-            )
-    """  # noqa: E501
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)[0]["0"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(text=choice)
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        f"model. Expected `dict` but `{type(choice)}` was "
+                        "received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(
+                text=choice["text"].strip(),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
+
+
+class AzureMLBaseEndpoint(BaseModel):
+    """Azure ML Online Endpoint models."""

     endpoint_url: str = ""
     """URL of pre-existing Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_URL`."""

-    endpoint_api_key: str = ""
+    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
+    """Type of the endpoint being consumed. Possible values are `serverless` for
+        pay-as-you-go and `realtime` for real-time endpoints."""
+
+    endpoint_api_key: SecretStr = convert_to_secret_str("")
     """Authentication Key for Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_API_KEY`."""
@@ -232,22 +344,106 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
     model_kwargs: Optional[dict] = None
     """Keyword arguments to pass to the model."""

-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        endpoint_key = get_from_dict_or_env(
-            values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
+    @root_validator(pre=True)
+    def validate_environ(cls, values: Dict) -> Dict:
+        values["endpoint_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
         )
-        endpoint_url = get_from_dict_or_env(
+        values["endpoint_url"] = get_from_dict_or_env(
             values, "endpoint_url", "AZUREML_ENDPOINT_URL"
         )
-        deployment_name = get_from_dict_or_env(
+        values["deployment_name"] = get_from_dict_or_env(
             values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
         )
-        http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
+        values["endpoint_api_type"] = get_from_dict_or_env(
+            values,
+            "endpoint_api_type",
+            "AZUREML_ENDPOINT_API_TYPE",
+            AzureMLEndpointApiType.realtime,
+        )
+
+        return values
+
+    @validator("content_formatter")
+    def validate_content_formatter(
+        cls, field_value: Any, values: Dict
+    ) -> ContentFormatterBase:
+        """Validate that content formatter is supported by endpoint type."""
+        endpoint_api_type = values.get("endpoint_api_type")
+        if endpoint_api_type not in field_value.supported_api_types:
+            raise ValueError(
+                f"Content formatter {type(field_value)} is not supported by this "
+                f"endpoint. Supported types are {field_value.supported_api_types} "
+                f"but endpoint is {endpoint_api_type}."
+            )
+        return field_value
+
+    @validator("endpoint_url")
+    def validate_endpoint_url(cls, field_value: Any) -> str:
+        """Validate that endpoint url is complete."""
+        if field_value.endswith("/"):
+            field_value = field_value[:-1]
+        if field_value.endswith("inference.ml.azure.com"):
+            raise ValueError(
+                "`endpoint_url` should contain the full invocation URL including "
+                "`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
+                "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
+            )
+        return field_value
+
+    @validator("endpoint_api_type")
+    def validate_endpoint_api_type(
+        cls, field_value: Any, values: Dict
+    ) -> AzureMLEndpointApiType:
+        """Validate that endpoint api type is compatible with the URL format."""
+        endpoint_url = values.get("endpoint_url")
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+            "/score"
+        ):
+            raise ValueError(
+                "Endpoints of type `realtime` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
+                " If your endpoint URL ends with `/v1/completions` or "
+                "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
+            )
+        if field_value == AzureMLEndpointApiType.serverless and not (
+            endpoint_url.endswith("/v1/completions")
+            or endpoint_url.endswith("/v1/chat/completions")
+        ):
+            raise ValueError(
+                "Endpoints of type `serverless` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions`"
+                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
+            )
+
+        return field_value
+
+    @validator("http_client", always=True)
+    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
+        """Validate that api key and python package exist in environment."""
+        endpoint_url = values.get("endpoint_url")
+        endpoint_key = values.get("endpoint_api_key")
+        deployment_name = values.get("deployment_name")
+        http_client = AzureMLEndpointClient(
+            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+        )
         return http_client
+
+
+class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint models.
+
+    Example:
+        .. code-block:: python
+
+            azure_llm = AzureMLOnlineEndpoint(
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLEndpointApiType.realtime,
+                endpoint_api_key="my-api-key",
+                content_formatter=content_formatter,
+            )
+    """  # noqa: E501

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -262,16 +458,17 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
         """Return type of llm."""
         return "azureml_endpoint"

-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Call out to an AzureML Managed Online endpoint.
+    ) -> LLMResult:
+        """Run the LLM on the given prompts.
+
         Args:
-            prompt: The prompt to pass into the model.
+            prompts: The prompts to pass into the model.
             stop: Optional list of stop words to use when generating.
         Returns:
             The string generated by the model.
@@ -280,12 +477,21 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
+        if stop:
+            _model_kwargs["stop"] = stop
+        generations = []

-        request_payload = self.content_formatter.format_request_payload(
-            prompt, _model_kwargs
-        )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
-        )
-        return generated_text
+        for prompt in prompts:
+            request_payload = self.content_formatter.format_request_payload(
+                prompt, _model_kwargs, self.endpoint_api_type
+            )
+            response_payload = self.http_client.call(
+                body=request_payload, run_manager=run_manager
+            )
+            generated_text = self.content_formatter.format_response_payload(
+                response_payload, self.endpoint_api_type
+            )
+            generations.append([generated_text])
+
+        return LLMResult(generations=generations)
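
Since `validate_environ` above resolves every setting through `get_from_dict_or_env`, the endpoint can also be configured entirely from the environment rather than the constructor. A minimal sketch with placeholder values:

```python
import os

from langchain_community.llms.azureml_endpoint import (
    AzureMLOnlineEndpoint,
    LlamaContentFormatter,
)

# Placeholder values; the variable names mirror the lookups performed in
# `validate_environ` above.
os.environ["AZUREML_ENDPOINT_URL"] = (
    "https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions"
)
os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key"
os.environ["AZUREML_ENDPOINT_API_TYPE"] = "serverless"

# No constructor arguments needed beyond the content formatter.
llm = AzureMLOnlineEndpoint(content_formatter=LlamaContentFormatter())
```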

File: libs/community/tests/integration_tests/chat_models/test_azureml_endpoint.py

@@ -5,31 +5,31 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_community.chat_models.azureml_endpoint import (
     AzureMLChatOnlineEndpoint,
-    LlamaContentFormatter,
+    LlamaChatContentFormatter,
 )


 def test_llama_call() -> None:
     """Test valid call to Open Source Foundation Model."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="Foo")])
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="Foo")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


-def test_timeout_kwargs() -> None:
-    """Test that timeout kwarg works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
+def test_temperature_kwargs() -> None:
+    """Test that the temperature kwarg works."""
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


 def test_message_history() -> None:
     """Test that multiple messages works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(
-        messages=[
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke(
+        [
             HumanMessage(content="Hello."),
             AIMessage(content="Hello!"),
             HumanMessage(content="How are you doing?"),
@@ -40,7 +40,7 @@ def test_message_history() -> None:

 def test_multiple_messages() -> None:
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
     message = HumanMessage(content="Hi!")
     response = chat.generate([[message], [message]])
File: libs/community/tests/integration_tests/llms/test_azureml_endpoint.py

@@ -7,6 +7,7 @@ from typing import Dict
 from urllib.request import HTTPError

 import pytest
+from langchain_core.pydantic_v1 import ValidationError

 from langchain_community.llms.azureml_endpoint import (
     AzureMLOnlineEndpoint,
@@ -26,7 +27,7 @@ def test_gpt2_call() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -38,7 +39,7 @@ def test_hf_call() -> None:
         deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
         content_formatter=HFContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -50,7 +51,7 @@ def test_dolly_call() -> None:
         deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
         content_formatter=DollyContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -81,7 +82,7 @@ def test_custom_formatter() -> None:
         deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
         content_formatter=CustomFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)
@@ -93,7 +94,7 @@ def test_missing_content_formatter() -> None:
         endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_invalid_request_format() -> None:
@@ -123,7 +124,31 @@ def test_invalid_request_format() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=CustomContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")
+
+
+def test_incorrect_url() -> None:
+    """Testing AzureML Endpoint for an incorrect URL"""
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url="https://endpoint.inference.com",
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")
+
+
+def test_incorrect_api_type() -> None:
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_type="serverless",
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")


 def test_incorrect_key() -> None:
@@ -135,7 +160,7 @@ def test_incorrect_key() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_saving_loading_llm(tmp_path: Path) -> None: