feat: adding paygo api support for Azure ML / Azure AI Studio (#14560)
- **Description:** Introducing support for LLMs and Chat models running in Azure AI Studio and Azure ML using the new deployment mode pay-as-you-go (model as a service).
- **Issue:** NA
- **Dependencies:** None.
- **Tag maintainer:** @prakharg-msft @gdyre
- **Twitter handle:** @santiagofacundo

Examples added:
* [docs/docs/integrations/llms/azure_ml.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_endpoint.ipynb)
* [docs/docs/integrations/chat/azureml_chat_endpoint.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in: parent 9ce177580a · commit 92e6a641fd
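For orientation, here is a minimal sketch of what the new pay-as-you-go (model as a service) mode looks like from user code, assembled from the notebook examples in this diff; the endpoint URL and API key are placeholders for your own deployment.

```python
from langchain.schema import HumanMessage
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    LlamaChatContentFormatter,
)

# Serverless (pay-as-you-go) deployments expose an OpenAI-style
# /v1/chat/completions route; `endpoint_api_type` selects the wire format.
chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",  # placeholder
    content_formatter=LlamaChatContentFormatter(),
)
response = chat.invoke(
    [HumanMessage(content="Will the Collatz conjecture ever be solved?")]
)
```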
@@ -15,9 +15,9 @@
  "source": [
   "# AzureMLChatOnlineEndpoint\n",
   "\n",
-  ">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. `Azure Foundation Models` include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
+  ">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides foundational and general purpose models from different providers.\n",
   ">\n",
-  ">[Azure Machine Learning Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints). After you train machine learning models or pipelines, you need to deploy them to production so that others can use them for inference. Inference is the process of applying new input data to the machine learning model or pipeline to generate outputs. While these outputs are typically referred to as \"predictions,\" inferencing can be used to generate outputs for other machine learning tasks, such as classification and clustering. In `Azure Machine Learning`, you perform inferencing by using endpoints and deployments. `Endpoints` and `Deployments` allow you to decouple the interface of your production workload from the implementation that serves it.\n",
+  ">In general, you need to deploy models in order to consume their predictions (inference). In `Azure Machine Learning`, [Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints) are used to deploy these models with real-time serving. They are based on the ideas of `Endpoints` and `Deployments`, which allow you to decouple the interface of your production workload from the implementation that serves it.\n",
   "\n",
   "This notebook goes over how to use a chat model hosted on an `Azure Machine Learning Endpoint`."
  ]
@@ -37,10 +37,11 @@
  "source": [
   "## Set up\n",
   "\n",
-  "To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
+  "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
   "\n",
-  "* `endpoint_api_key`: The API key provided by the endpoint\n",
-  "* `endpoint_url`: The REST endpoint url provided by the endpoint"
+  "* `endpoint_url`: The REST endpoint URL provided by the endpoint.\n",
+  "* `endpoint_api_type`: Use `endpoint_api_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_api_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+  "* `endpoint_api_key`: The API key provided by the endpoint"
  ]
 },
 {
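The same parameters can be resolved from the environment: the `validate_environ` root validator added further down in this diff reads `AZUREML_ENDPOINT_URL`, `AZUREML_ENDPOINT_API_KEY`, and `AZUREML_ENDPOINT_API_TYPE`. A minimal sketch with placeholder values:

```python
import os

# Placeholders; the variable names come from validate_environ in this diff.
os.environ["AZUREML_ENDPOINT_URL"] = "https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions"
os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key"
os.environ["AZUREML_ENDPOINT_API_TYPE"] = "serverless"

from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    LlamaChatContentFormatter,
)

# No explicit endpoint arguments needed: they are resolved from the environment.
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
```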
@@ -51,7 +52,40 @@
  "\n",
  "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
  "\n",
- "* `LLamaContentFormatter`: Formats request and response data for LLaMa2-chat"
+ "* `LlamaChatContentFormatter`: Formats request and response data for LLaMa2-chat\n",
+ "\n",
+ "*Note: `langchain.chat_models.azureml_endpoint.LlamaContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.LlamaChatContentFormatter`.*\n",
+ "\n",
+ "You can implement custom content formatters specific for your model deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`."
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "## Examples\n",
+  "\n",
+  "The following section contains examples of how to use this class:"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "from langchain.schema import HumanMessage\n",
+  "from langchain_community.chat_models.azureml_endpoint import (\n",
+  "    AzureMLEndpointApiType,\n",
+  "    LlamaChatContentFormatter,\n",
+  ")"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Example: Chat completions with real-time endpoints"
 ]
},
{
@@ -76,11 +110,79 @@
  "\n",
  "chat = AzureMLChatOnlineEndpoint(\n",
  "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
+ "    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
  "    endpoint_api_key=\"my-api-key\",\n",
- "    content_formatter=LlamaContentFormatter,\n",
+ "    content_formatter=LlamaChatContentFormatter(),\n",
  ")\n",
- "response = chat(\n",
- "    messages=[HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+ "response = chat.invoke(\n",
+ "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
  ")\n",
  "response"
 ]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Example: Chat completions with pay-as-you-go deployments (model as a service)"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "chat = AzureMLChatOnlineEndpoint(\n",
+  "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
+  "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+  "    endpoint_api_key=\"my-api-key\",\n",
+  "    content_formatter=LlamaChatContentFormatter(),\n",
+  ")\n",
+  "response = chat.invoke(\n",
+  "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+  ")\n",
+  "response"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "If you need to pass additional parameters to the model, use the `model_kwargs` argument:"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "chat = AzureMLChatOnlineEndpoint(\n",
+  "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
+  "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+  "    endpoint_api_key=\"my-api-key\",\n",
+  "    content_formatter=LlamaChatContentFormatter(),\n",
+  "    model_kwargs={\"temperature\": 0.8},\n",
+  ")"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "Parameters can also be passed during invocation:"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "response = chat.invoke(\n",
+  "    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")],\n",
+  "    max_tokens=512,\n",
+  ")\n",
+  "response"
 ]
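Multi-turn conversations work the same way: pass the running message list to `invoke`, as the integration tests at the bottom of this diff do. A short sketch, reusing the `chat` instance from the cells above:

```python
from langchain.schema import AIMessage, HumanMessage

# Append each turn to the list; the formatter maps roles to the endpoint schema.
response = chat.invoke(
    [
        HumanMessage(content="Hello."),
        AIMessage(content="Hello!"),
        HumanMessage(content="How are you doing?"),
    ]
)
```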
@@ -6,9 +6,9 @@
  "source": [
   "# Azure ML\n",
   "\n",
-  "[Azure ML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
+  "[Azure ML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides foundational and general purpose models from different providers.\n",
   "\n",
-  "This notebook goes over how to use an LLM hosted on an `AzureML online endpoint`"
+  "This notebook goes over how to use an LLM hosted on an `Azure ML Online Endpoint`."
  ]
 },
 {
@@ -26,11 +26,12 @@
  "source": [
   "## Set up\n",
   "\n",
-  "To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
+  "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
   "\n",
-  "* `endpoint_api_key`: Required - The API key provided by the endpoint\n",
-  "* `endpoint_url`: Required - The REST endpoint url provided by the endpoint\n",
-  "* `deployment_name`: Not required - The deployment name of the model using the endpoint"
+  "* `endpoint_url`: The REST endpoint URL provided by the endpoint.\n",
+  "* `endpoint_api_type`: Use `endpoint_api_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_api_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+  "* `endpoint_api_key`: The API key provided by the endpoint.\n",
+  "* `deployment_name`: (Optional) The deployment name of the model using the endpoint."
  ]
 },
 {
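For realtime endpoints that host several deployments, the optional `deployment_name` pins the call to one of them (it can also come from the `AZUREML_DEPLOYMENT_NAME` environment variable, read by `validate_environ` further down this diff). A minimal sketch with placeholder values:

```python
from langchain_community.llms.azureml_endpoint import (
    AzureMLOnlineEndpoint,
    OSSContentFormatter,
)

# Placeholder values; `deployment_name` is optional and routes the request to a
# specific deployment behind the realtime endpoint.
llm = AzureMLOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
    endpoint_api_key="my-api-key",
    deployment_name="my-deployment",
    content_formatter=OSSContentFormatter(),
)
```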
@@ -46,31 +47,107 @@
  "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
  "* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
  "\n",
- "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*\n",
- "\n",
- "Below is an example using a summarization model from Hugging Face."
+ "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*"
 ]
},
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "### Custom Content Formatter"
+ "## Examples"
 ]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Example: LlaMa 2 completions with real-time endpoints"
+ ]
+},
+{
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
 "metadata": {},
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "HaSeul won her first music show trophy with \"So What\" on Mnet's M Countdown. Loona released their second EP titled [#] (read as hash] on February 5, 2020. HaSeul did not take part in the promotion of the album because of mental health issues. On October 19, 2020, they released their third EP called [12:00]. It was their first album to enter the Billboard 200, debuting at number 112. On June 2, 2021, the group released their fourth EP called Yummy-Yummy. On August 27, it was announced that they are making their Japanese debut on September 15 under Universal Music Japan sublabel EMI Records.\n"
-   ]
-  }
- ],
+ "outputs": [],
 "source": [
+  "from langchain.schema import HumanMessage\n",
+  "from langchain_community.llms.azureml_endpoint import (\n",
+  "    AzureMLEndpointApiType,\n",
+  "    LlamaContentFormatter,\n",
+  ")\n",
+  "\n",
+  "llm = AzureMLOnlineEndpoint(\n",
+  "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
+  "    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+  "    endpoint_api_key=\"my-api-key\",\n",
+  "    content_formatter=LlamaContentFormatter(),\n",
+  "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
+  ")\n",
+  "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
+  "response"
 ]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "Model parameters can also be indicated during invocation:"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "response = llm.invoke(\"Write me a song about sparkling water:\", temperature=0.5)\n",
+  "response"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Example: Text completions with pay-as-you-go deployments (model as a service)"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+  "from langchain.schema import HumanMessage\n",
+  "from langchain_community.llms.azureml_endpoint import (\n",
+  "    AzureMLEndpointApiType,\n",
+  "    LlamaContentFormatter,\n",
+  ")\n",
+  "\n",
+  "llm = AzureMLOnlineEndpoint(\n",
+  "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions\",\n",
+  "    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+  "    endpoint_api_key=\"my-api-key\",\n",
+  "    content_formatter=LlamaContentFormatter(),\n",
+  "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
+  ")\n",
+  "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
+  "response"
+ ]
+},
+{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+  "### Example: Custom content formatter\n",
+  "\n",
+  "Below is an example using a summarization model from Hugging Face."
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
  "import json\n",
  "import os\n",
@@ -104,6 +181,7 @@
  "content_formatter = CustomFormatter()\n",
  "\n",
  "llm = AzureMLOnlineEndpoint(\n",
+ "    endpoint_api_type=\"realtime\",\n",
  "    endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
  "    endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
  "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
@@ -132,7 +210,7 @@
  "that Loona will release the double A-side single, \"Hula Hoop / Star Seed\" on September 15, with a physical CD release on October \n",
  "20.[53] In December, Chuu filed an injunction to suspend her exclusive contract with Blockberry Creative.[54][55]\n",
  "\"\"\"\n",
- "summarized_text = llm(large_text)\n",
+ "summarized_text = llm.invoke(large_text)\n",
  "print(summarized_text)"
 ]
},
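The body of the notebook's `CustomFormatter` falls outside the hunks shown here, so the following is only a hedged reconstruction: the class name matches the notebook, but the request/response schema is an assumption based on the surrounding BART summarization example and the new `api_type` signature introduced in this diff.

```python
import json
from typing import Dict

from langchain_core.outputs import Generation

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)


class CustomFormatter(ContentFormatterBase):
    """Hypothetical formatter for a BART-style summarization endpoint."""

    content_type = "application/json"
    accepts = "application/json"

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
    ) -> bytes:
        # Assumed input schema: a list of documents plus generation parameters.
        input_str = json.dumps({"inputs": [prompt], "parameters": model_kwargs})
        return str.encode(input_str)

    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
    ) -> Generation:
        # Assumed output schema: one summary per input document.
        summary = json.loads(output)[0]["summary_text"]
        return Generation(text=summary)
```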
@@ -140,22 +218,14 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "### Dolly with LLMChain"
+ "### Example: Dolly with LLMChain"
 ]
},
{
 "cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
 "metadata": {},
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "Many people are willing to talk about themselves; it's others who seem to be stuck up. Try to understand others where they're coming from. Like minded people can build a tribe together.\n"
-   ]
-  }
- ],
+ "outputs": [],
 "source": [
  "from langchain.chains import LLMChain\n",
  "from langchain.prompts import PromptTemplate\n",
@@ -177,31 +247,22 @@
  ")\n",
  "\n",
  "chain = LLMChain(llm=llm, prompt=prompt)\n",
- "print(chain.run({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
+ "print(chain.invoke({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
 ]
},
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "### Serializing an LLM\n",
+ "## Serializing an LLM\n",
  "You can also save and load LLM configurations"
 ]
},
{
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
 "metadata": {},
- "outputs": [
-  {
-   "name": "stdout",
-   "output_type": "stream",
-   "text": [
-    "\u001b[1mAzureMLOnlineEndpoint\u001b[0m\n",
-    "Params: {'deployment_name': 'databricks-dolly-v2-12b-4', 'model_kwargs': {'temperature': 0.2, 'max_tokens': 150, 'top_p': 0.8, 'frequency_penalty': 0.32, 'presence_penalty': 0.072}}\n"
-   ]
-  }
- ],
+ "outputs": [],
 "source": [
  "from langchain_community.llms.loading import load_llm\n",
  "\n",
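Saving writes the endpoint configuration to disk and `load_llm` rebuilds the model from it. A short sketch, assuming the `llm` configured above and a hypothetical file name:

```python
from langchain_community.llms.loading import load_llm

# Persist the endpoint configuration; "azureml.json" is a hypothetical path.
llm.save("azureml.json")

# Rebuild an equivalent LLM from the saved configuration.
loaded_llm = load_llm("azureml.json")
print(loaded_llm)
```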
@@ -224,9 +285,9 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3 (ipykernel)",
+  "display_name": "langchain",
   "language": "python",
-  "name": "python3"
+  "name": "langchain"
  },
  "language_info": {
   "codemirror_mode": {
@@ -238,7 +299,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.12"
+  "version": "3.11.5"
  }
 },
 "nbformat": 4,
@@ -1,8 +1,8 @@
 import json
 from typing import Any, Dict, List, Optional, cast

-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -10,16 +10,24 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.pydantic_v1 import SecretStr, validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.outputs import ChatGeneration, ChatResult

 from langchain_community.llms.azureml_endpoint import (
-    AzureMLEndpointClient,
+    AzureMLBaseEndpoint,
+    AzureMLEndpointApiType,
     ContentFormatterBase,
 )


 class LlamaContentFormatter(ContentFormatterBase):
+    def __init__(self):
+        raise TypeError(
+            "`LlamaContentFormatter` is deprecated for chat models. Use "
+            "`LlamaChatContentFormatter` instead."
+        )
+
+
+class LlamaChatContentFormatter(ContentFormatterBase):
     """Content formatter for `LLaMA`."""

     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@@ -45,7 +53,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
+            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -53,79 +61,96 @@
             }
         else:
             supported = ",".join(
-                [role for role in LlamaContentFormatter.SUPPORTED_ROLES]
+                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
                 Supported roles for the LLaMa Foundation Model: {supported}"""
             )

-    def _format_request_payload(
-        self, messages: List[BaseMessage], model_kwargs: Dict
-    ) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> str:
         """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaContentFormatter._convert_message_to_dict(message)
+            LlamaChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        prompt = json.dumps(
-            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
-        )
-        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        """Formats the request according to the chosen api"""
-        return str.encode(prompt)
-
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> ChatGeneration:
         """Formats response"""
-        return json.loads(output)["output"]
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)["output"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice.strip(),
+                    type="assistant",
+                ),
+                generation_info=None,
+            )
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        f"model. Expected `dict` but `{type(choice)}` was received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice["message"]["content"].strip(),
+                    type=choice["message"]["role"],
+                ),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")


-class AzureMLChatOnlineEndpoint(SimpleChatModel):
-    """`AzureML` Chat models API.
+class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint chat models.

     Example:
         .. code-block:: python

             azure_chat = AzureMLChatOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLEndpointApiType.realtime,
                 endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
+                content_formatter=chat_content_formatter,
             )
-    """
-
-    endpoint_url: str = ""
-    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
-    env var `AZUREML_ENDPOINT_URL`."""
-
-    endpoint_api_key: SecretStr = convert_to_secret_str("")
-    """Authentication Key for Endpoint. Should be passed to constructor or specified as
-    env var `AZUREML_ENDPOINT_API_KEY`."""
-
-    http_client: Any = None  #: :meta private:
-
-    content_formatter: Any = None
-    """The content formatter that provides an input and output
-    transform function to handle formats between the LLM and
-    the endpoint"""
-
-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model."""
-
-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        values["endpoint_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
-        )
-        endpoint_url = get_from_dict_or_env(
-            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
-        )
-        http_client = AzureMLEndpointClient(
-            endpoint_url, values["endpoint_api_key"].get_secret_value()
-        )
-        return http_client
+    """  # noqa: E501

     @property
     def _identifying_params(self) -> Dict[str, Any]:
@@ -140,13 +165,13 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
         """Return type of llm."""
         return "azureml_chat_endpoint"

-    def _call(
+    def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> ChatResult:
         """Call out to an AzureML Managed Online endpoint.
         Args:
             messages: The messages in the conversation with the chat model.
@@ -158,12 +183,17 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
         if stop:
             _model_kwargs["stop"] = stop

-        request_payload = self.content_formatter._format_request_payload(
-            messages, _model_kwargs
+        request_payload = self.content_formatter.format_request_payload(
+            messages, _model_kwargs, self.endpoint_api_type
         )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
+        response_payload = self.http_client.call(
+            body=request_payload, run_manager=run_manager
         )
-        return generated_text
+        generations = self.content_formatter.format_response_payload(
+            response_payload, self.endpoint_api_type
+        )
+        return ChatResult(generations=[generations])
@@ -2,12 +2,14 @@ import json
 import urllib.request
 import warnings
 from abc import abstractmethod
+from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional

-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.pydantic_v1 import BaseModel, validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


 class AzureMLEndpointClient(object):
@@ -26,7 +28,12 @@ class AzureMLEndpointClient(object):
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name

-    def call(self, body: bytes, **kwargs: Any) -> bytes:
+    def call(
+        self,
+        body: bytes,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> bytes:
         """call."""

         # The azureml-model-deployment header will force the request to go to a
@@ -45,6 +52,16 @@ class AzureMLEndpointClient(object):
         return result


+class AzureMLEndpointApiType(str, Enum):
+    """Azure ML endpoints API types. Use `realtime` for models deployed in hosted
+    infrastructure, or `serverless` for models deployed as a service with a
+    pay-as-you-go billing or PTU.
+    """
+
+    realtime = "realtime"
+    serverless = "serverless"
+
+
 class ContentFormatterBase:
     """Transform request and response of AzureML endpoint to match with
     required schema.
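The two members of this enum imply different wire formats, which is why every formatter below now receives `api_type`. As a sketch, these are the two request payload shapes the `LlamaContentFormatter` in this diff produces (prompt and parameters are placeholders):

```python
import json

model_kwargs = {"temperature": 0.8}

# realtime: Azure ML's input_data envelope (scoring endpoints under /score).
realtime_payload = json.dumps(
    {"input_data": {"input_string": ['"Hello"'], "parameters": model_kwargs}}
)

# serverless: OpenAI-style flat schema (pay-as-you-go, /v1/completions).
serverless_payload = json.dumps({"prompt": "Hello", **model_kwargs})
```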
@@ -61,7 +78,8 @@ class ContentFormatterBase:
         def format_request_payload(
             self,
             prompt: str,
-            model_kwargs: Dict
+            model_kwargs: Dict,
+            api_type: AzureMLEndpointApiType,
         ) -> bytes:
             input_str = json.dumps(
                 {
@@ -71,7 +89,9 @@ class ContentFormatterBase:
             )
             return str.encode(input_str)

-        def format_response_payload(self, output: str) -> str:
+        def format_response_payload(
+            self, output: str, api_type: AzureMLEndpointApiType
+        ) -> str:
             response_json = json.loads(output)
             return response_json[0]["0"]
     """
@@ -81,6 +101,12 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""

+    format_error_msg: Optional[str] = (
+        "Error while formatting response payload for chat model of type "
+        "`{api_type}`. Are you using the right formatter for the deployed "
+        "model and endpoint type?"
+    )
+
     @staticmethod
     def escape_special_characters(prompt: str) -> str:
         """Escapes any special characters in `prompt`"""
@@ -100,15 +126,32 @@ class ContentFormatterBase:

         return prompt

+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        """Supported APIs for the given formatter. Azure ML supports
+        deploying models using different hosting methods. Each method may have
+        a different API structure."""
+
+        return [AzureMLEndpointApiType.realtime]
+
     @abstractmethod
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    def format_request_payload(
+        self,
+        prompt: str,
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> bytes:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """

     @abstractmethod
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> Generation:
         """Formats the response body according to the output
         schema of the model. Returns the data type that is
         received from the response.
@@ -118,15 +161,27 @@ class ContentFormatterBase:
 class GPT2ContentFormatter(ContentFormatterBase):
     """Content handler for GPT2"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["0"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class OSSContentFormatter(GPT2ContentFormatter):
@@ -148,21 +203,39 @@ class OSSContentFormatter(GPT2ContentFormatter):
 class HFContentFormatter(ContentFormatterBase):
     """Content handler for LLMs from the HuggingFace catalog."""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["generated_text"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]["generated_text"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class DollyContentFormatter(ContentFormatterBase):
     """Content handler for the Dolly-v2-12b model"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {
@@ -172,49 +245,88 @@ class DollyContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)


 class LlamaContentFormatter(ContentFormatterBase):
     """Content formatter for LLaMa"""

-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         """Formats the request according to the chosen api"""
         prompt = ContentFormatterBase.escape_special_characters(prompt)
-        request_payload = json.dumps(
-            {
-                "input_data": {
-                    "input_string": [f'"{prompt}"'],
-                    "parameters": model_kwargs,
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": [f'"{prompt}"'],
+                        "parameters": model_kwargs,
+                    }
                 }
-            }
-        )
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
         return str.encode(request_payload)

-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
         """Formats response"""
-        return json.loads(output)[0]["0"]
-
-
-class AzureMLOnlineEndpoint(LLM, BaseModel):
-    """Azure ML Online Endpoint models.
-
-    Example:
-        .. code-block:: python
-
-            azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
-            )
-    """  # noqa: E501
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)[0]["0"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(text=choice)
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        f"model. Expected `dict` but `{type(choice)}` was "
+                        "received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(
+                text=choice["text"].strip(),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")


+class AzureMLBaseEndpoint(BaseModel):
+    """Azure ML Online Endpoint models."""
+
     endpoint_url: str = ""
     """URL of pre-existing Endpoint. Should be passed to constructor or specified as
     env var `AZUREML_ENDPOINT_URL`."""

-    endpoint_api_key: str = ""
+    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
+    """Type of the endpoint being consumed. Possible values are `serverless` for
+    pay-as-you-go and `realtime` for real-time endpoints."""
+
+    endpoint_api_key: SecretStr = convert_to_secret_str("")
     """Authentication Key for Endpoint. Should be passed to constructor or specified as
     env var `AZUREML_ENDPOINT_API_KEY`."""
@@ -232,22 +344,106 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
     model_kwargs: Optional[dict] = None
     """Keyword arguments to pass to the model."""

-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        endpoint_key = get_from_dict_or_env(
-            values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
+    @root_validator(pre=True)
+    def validate_environ(cls, values: Dict) -> Dict:
+        values["endpoint_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
         )
-        endpoint_url = get_from_dict_or_env(
+        values["endpoint_url"] = get_from_dict_or_env(
             values, "endpoint_url", "AZUREML_ENDPOINT_URL"
         )
-        deployment_name = get_from_dict_or_env(
+        values["deployment_name"] = get_from_dict_or_env(
             values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
         )
-        http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
+        values["endpoint_api_type"] = get_from_dict_or_env(
+            values,
+            "endpoint_api_type",
+            "AZUREML_ENDPOINT_API_TYPE",
+            AzureMLEndpointApiType.realtime,
+        )
+
+        return values
+
+    @validator("content_formatter")
+    def validate_content_formatter(
+        cls, field_value: Any, values: Dict
+    ) -> ContentFormatterBase:
+        """Validate that content formatter is supported by endpoint type."""
+        endpoint_api_type = values.get("endpoint_api_type")
+        if endpoint_api_type not in field_value.supported_api_types:
+            raise ValueError(
+                f"Content formatter {type(field_value)} is not supported by this "
+                f"endpoint. Supported types are {field_value.supported_api_types} "
+                f"but endpoint is {endpoint_api_type}."
+            )
+        return field_value
+
+    @validator("endpoint_url")
+    def validate_endpoint_url(cls, field_value: Any) -> str:
+        """Validate that endpoint url is complete."""
+        if field_value.endswith("/"):
+            field_value = field_value[:-1]
+        if field_value.endswith("inference.ml.azure.com"):
+            raise ValueError(
+                "`endpoint_url` should contain the full invocation URL including "
+                "`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
+                "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
+            )
+        return field_value
+
+    @validator("endpoint_api_type")
+    def validate_endpoint_api_type(
+        cls, field_value: Any, values: Dict
+    ) -> AzureMLEndpointApiType:
+        """Validate that endpoint api type is compatible with the URL format."""
+        endpoint_url = values.get("endpoint_url")
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+            "/score"
+        ):
+            raise ValueError(
+                "Endpoints of type `realtime` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
+                " If your endpoint URL ends with `/v1/completions` or "
+                "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
+            )
+        if field_value == AzureMLEndpointApiType.serverless and not (
+            endpoint_url.endswith("/v1/completions")
+            or endpoint_url.endswith("/v1/chat/completions")
+        ):
+            raise ValueError(
+                "Endpoints of type `serverless` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions`"
+                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
+            )
+
+        return field_value
+
+    @validator("http_client", always=True)
+    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
+        """Validate that api key and python package exist in environment."""
+        endpoint_url = values.get("endpoint_url")
+        endpoint_key = values.get("endpoint_api_key")
+        deployment_name = values.get("deployment_name")
+
+        http_client = AzureMLEndpointClient(
+            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+        )
         return http_client
+
+
+class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint models.
+
+    Example:
+        .. code-block:: python
+
+            azure_llm = AzureMLOnlineEndpoint(
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLEndpointApiType.realtime,
+                endpoint_api_key="my-api-key",
+                content_formatter=content_formatter,
+            )
+    """  # noqa: E501

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
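Taken together, these validators make a mismatched URL/API-type pair fail at construction time rather than at request time. A sketch of the expected behavior, mirroring the new integration tests at the end of this diff (values are placeholders):

```python
import pytest
from langchain_core.pydantic_v1 import ValidationError

from langchain_community.llms.azureml_endpoint import (
    AzureMLOnlineEndpoint,
    OSSContentFormatter,
)

# A `serverless` api_type combined with a realtime-style `/score` URL should be
# rejected while the model is being constructed.
with pytest.raises(ValidationError):
    AzureMLOnlineEndpoint(
        endpoint_url="https://my-endpoint.eastus2.inference.ml.azure.com/score",
        endpoint_api_type="serverless",
        endpoint_api_key="my-api-key",
        content_formatter=OSSContentFormatter(),
    )
```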
@@ -262,16 +458,17 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
         """Return type of llm."""
         return "azureml_endpoint"

-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Call out to an AzureML Managed Online endpoint.
+    ) -> LLMResult:
+        """Run the LLM on the given prompts.
+
         Args:
-            prompt: The prompt to pass into the model.
+            prompts: The prompts to pass into the model.
             stop: Optional list of stop words to use when generating.
         Returns:
             The string generated by the model.
@@ -280,12 +477,21 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
         if stop:
             _model_kwargs["stop"] = stop
+        generations = []

-        request_payload = self.content_formatter.format_request_payload(
-            prompt, _model_kwargs
-        )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
-        )
-        return generated_text
+        for prompt in prompts:
+            request_payload = self.content_formatter.format_request_payload(
+                prompt, _model_kwargs, self.endpoint_api_type
+            )
+            response_payload = self.http_client.call(
+                body=request_payload, run_manager=run_manager
+            )
+            generated_text = self.content_formatter.format_response_payload(
+                response_payload, self.endpoint_api_type
+            )
+            generations.append([generated_text])
+
+        return LLMResult(generations=generations)
@@ -5,31 +5,31 @@ from langchain_core.outputs import ChatGeneration, LLMResult

 from langchain_community.chat_models.azureml_endpoint import (
     AzureMLChatOnlineEndpoint,
-    LlamaContentFormatter,
+    LlamaChatContentFormatter,
 )


 def test_llama_call() -> None:
     """Test valid call to Open Source Foundation Model."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="Foo")])
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="Foo")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


-def test_timeout_kwargs() -> None:
+def test_temperature_kwargs() -> None:
     """Test that temperature kwarg works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


 def test_message_history() -> None:
     """Test that multiple messages works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(
-        messages=[
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke(
+        [
             HumanMessage(content="Hello."),
             AIMessage(content="Hello!"),
             HumanMessage(content="How are you doing?"),
@@ -40,7 +40,7 @@ def test_message_history() -> None:


 def test_multiple_messages() -> None:
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
     message = HumanMessage(content="Hi!")
     response = chat.generate([[message], [message]])
@@ -7,6 +7,7 @@ from typing import Dict
 from urllib.request import HTTPError

 import pytest
+from langchain_core.pydantic_v1 import ValidationError

 from langchain_community.llms.azureml_endpoint import (
     AzureMLOnlineEndpoint,
@@ -26,7 +27,7 @@ def test_gpt2_call() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -38,7 +39,7 @@ def test_hf_call() -> None:
         deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
         content_formatter=HFContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -50,7 +51,7 @@ def test_dolly_call() -> None:
         deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
         content_formatter=DollyContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -81,7 +82,7 @@ def test_custom_formatter() -> None:
         deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
         content_formatter=CustomFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -93,7 +94,7 @@ def test_missing_content_formatter() -> None:
         endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_invalid_request_format() -> None:
@@ -123,7 +124,31 @@ def test_invalid_request_format() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=CustomContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")
+
+
+def test_incorrect_url() -> None:
+    """Testing AzureML Endpoint for an incorrect URL"""
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url="https://endpoint.inference.com",
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")
+
+
+def test_incorrect_api_type() -> None:
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_type="serverless",
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")


 def test_incorrect_key() -> None:
@@ -135,7 +160,7 @@ def test_incorrect_key() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_saving_loading_llm(tmp_path: Path) -> None: