Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-13 14:50:00 +00:00)
feat: adding paygo api support for Azure ML / Azure AI Studio (#14560)
- **Description:** Introducing support for LLMs and chat models running in Azure AI Studio and Azure ML using the new pay-as-you-go deployment mode (model as a service).
- **Issue:** N/A
- **Dependencies:** None.
- **Tag maintainer:** @prakharg-msft @gdyre
- **Twitter handle:** @santiagofacundo

Examples added:
* [docs/docs/integrations/llms/azure_ml.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_endpoint.ipynb)
* [docs/docs/integrations/chat/azureml_chat_endpoint.ipynb](https://github.com/santiagxf/langchain/blob/santiagxf/azureml-endpoints-paygo-community/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
9ce177580a
commit
92e6a641fd
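For context, here is a minimal sketch of the serverless (pay-as-you-go) chat usage this PR enables, assembled from the notebook examples in the diff below; the endpoint URL and API key are placeholders:

```python
# Serverless (pay-as-you-go) chat usage introduced by this PR.
# The endpoint URL and API key below are placeholders.
from langchain.schema import HumanMessage
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    LlamaChatContentFormatter,
)

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=LlamaChatContentFormatter(),
)
response = chat.invoke([HumanMessage(content="Will the Collatz conjecture ever be solved?")])
```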
--- a/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb
+++ b/docs/docs/integrations/chat/azureml_chat_endpoint.ipynb
@@ -15,9 +15,9 @@
 "source": [
 "# AzureMLChatOnlineEndpoint\n",
 "\n",
-">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. `Azure Foundation Models` include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
+">[Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides foundational and general purpose models from different providers.\n",
 ">\n",
-">[Azure Machine Learning Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints). After you train machine learning models or pipelines, you need to deploy them to production so that others can use them for inference. Inference is the process of applying new input data to the machine learning model or pipeline to generate outputs. While these outputs are typically referred to as \"predictions,\" inferencing can be used to generate outputs for other machine learning tasks, such as classification and clustering. In `Azure Machine Learning`, you perform inferencing by using endpoints and deployments. `Endpoints` and `Deployments` allow you to decouple the interface of your production workload from the implementation that serves it.\n",
+">In general, you need to deploy models in order to consume their predictions (inference). In `Azure Machine Learning`, [Online Endpoints](https://learn.microsoft.com/en-us/azure/machine-learning/concept-endpoints) are used to deploy these models with real-time serving. They are based on the ideas of `Endpoints` and `Deployments`, which allow you to decouple the interface of your production workload from the implementation that serves it.\n",
 "\n",
 "This notebook goes over how to use a chat model hosted on an `Azure Machine Learning Endpoint`."
 ]
@@ -37,10 +37,11 @@
 "source": [
 "## Set up\n",
 "\n",
-"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
+"You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
 "\n",
-"* `endpoint_api_key`: The API key provided by the endpoint\n",
-"* `endpoint_url`: The REST endpoint url provided by the endpoint"
+"* `endpoint_url`: The REST endpoint URL provided by the endpoint.\n",
+"* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+"* `endpoint_api_key`: The API key provided by the endpoint"
 ]
 },
 {
@@ -51,7 +52,40 @@
 "\n",
 "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match the required schema. Since there is a wide range of models in the model catalog, each of which may process data differently, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
 "\n",
-"* `LLamaContentFormatter`: Formats request and response data for LLaMa2-chat"
+"* `LLamaChatContentFormatter`: Formats request and response data for LLaMa2-chat\n",
+"\n",
+"*Note: `langchain.chat_models.azureml_endpoint.LLamaContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.LLamaChatContentFormatter`.*\n",
+"\n",
+"You can implement custom content formatters specific to your model by deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Examples\n",
+"\n",
+"The following section contains examples of how to use this class:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from langchain.schema import HumanMessage\n",
+"from langchain_community.chat_models.azureml_endpoint import (\n",
+"    AzureMLEndpointApiType,\n",
+"    LlamaChatContentFormatter,\n",
+")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Example: Chat completions with real-time endpoints"
 ]
 },
 {
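The cell above notes that custom content formatters can derive from `ContentFormatterBase`. Here is a minimal sketch of what that can look like for a chat model, modeled on the `LlamaChatContentFormatter` changes later in this diff; the `{"dialog": ...}` request and `{"output": ...}` response shapes are assumptions for an imaginary model, not a real AzureML contract:

```python
# Hypothetical custom chat formatter derived from ContentFormatterBase,
# modeled on LlamaChatContentFormatter below. Payload shapes are assumptions.
import json
from typing import Dict, List

from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)


class DialogChatFormatter(ContentFormatterBase):
    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.realtime]

    def format_request_payload(
        self,
        messages: List[BaseMessage],
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType,
    ) -> bytes:
        # Serialize the conversation into the model's assumed input schema.
        dialog = [{"role": m.type, "content": m.content} for m in messages]
        return str.encode(json.dumps({"dialog": dialog, **model_kwargs}))

    def format_response_payload(
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> ChatGeneration:
        # Extract the completion from the model's assumed output schema.
        text = json.loads(output)["output"]
        return ChatGeneration(message=BaseMessage(content=text, type="assistant"))
```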
@@ -76,11 +110,79 @@
 "\n",
 "chat = AzureMLChatOnlineEndpoint(\n",
 "    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
+"    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
 "    endpoint_api_key=\"my-api-key\",\n",
-"    content_formatter=LlamaContentFormatter,\n",
+"    content_formatter=LlamaChatContentFormatter(),\n",
 ")\n",
-"response = chat(\n",
-"    messages=[HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+"response = chat.invoke(\n",
+"    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+")\n",
+"response"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Example: Chat completions with pay-as-you-go deployments (model as a service)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"chat = AzureMLChatOnlineEndpoint(\n",
+"    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
+"    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+"    endpoint_api_key=\"my-api-key\",\n",
+"    content_formatter=LlamaChatContentFormatter(),\n",
+")\n",
+"response = chat.invoke(\n",
+"    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
+")\n",
+"response"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"If you need to pass additional parameters to the model, use the `model_kwargs` argument:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"chat = AzureMLChatOnlineEndpoint(\n",
+"    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
+"    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+"    endpoint_api_key=\"my-api-key\",\n",
+"    content_formatter=LlamaChatContentFormatter(),\n",
+"    model_kwargs={\"temperature\": 0.8},\n",
+")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Parameters can also be passed during invocation:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"response = chat.invoke(\n",
+"    [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")],\n",
+"    max_tokens=512,\n",
 ")\n",
 "response"
 ]
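One detail worth knowing before the LLM notebook below: the URL suffix and `endpoint_api_type` must agree, per validators added to `langchain_community/llms/azureml_endpoint.py` later in this diff (`/score` for `realtime`; `/v1/completions` or `/v1/chat/completions` for `serverless`). A small sketch of the failure mode, with placeholder URL and key:

```python
# Sketch: a mismatched endpoint_api_type and URL suffix fails at construction
# time (pydantic wraps the validator's ValueError). URL/key are placeholders.
from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    AzureMLOnlineEndpoint,
    LlamaContentFormatter,
)

try:
    llm = AzureMLOnlineEndpoint(
        endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
        endpoint_api_type=AzureMLEndpointApiType.serverless,  # mismatch: /score is realtime
        endpoint_api_key="my-api-key",
        content_formatter=LlamaContentFormatter(),
    )
except ValueError as err:
    print(err)
```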
--- a/docs/docs/integrations/llms/azure_ml.ipynb
+++ b/docs/docs/integrations/llms/azure_ml.ipynb
@@ -6,9 +6,9 @@
 "source": [
 "# Azure ML\n",
 "\n",
-"[Azure ML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides Azure Foundation Models and OpenAI Models. Azure Foundation Models include various open-source models and popular Hugging Face models. Users can also import models of their liking into AzureML.\n",
+"[Azure ML](https://azure.microsoft.com/en-us/products/machine-learning/) is a platform used to build, train, and deploy machine learning models. Users can explore the types of models to deploy in the Model Catalog, which provides foundational and general purpose models from different providers.\n",
 "\n",
-"This notebook goes over how to use an LLM hosted on an `AzureML online endpoint`"
+"This notebook goes over how to use an LLM hosted on an `Azure ML Online Endpoint`."
 ]
 },
 {
@@ -26,11 +26,12 @@
 "source": [
 "## Set up\n",
 "\n",
-"To use the wrapper, you must [deploy a model on AzureML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) and obtain the following parameters:\n",
+"You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
 "\n",
-"* `endpoint_api_key`: Required - The API key provided by the endpoint\n",
-"* `endpoint_url`: Required - The REST endpoint url provided by the endpoint\n",
-"* `deployment_name`: Not required - The deployment name of the model using the endpoint"
+"* `endpoint_url`: The REST endpoint URL provided by the endpoint.\n",
+"* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+"* `endpoint_api_key`: The API key provided by the endpoint.\n",
+"* `deployment_name`: (Optional) The deployment name of the model using the endpoint."
 ]
 },
 {
@@ -46,31 +47,107 @@
 "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
 "* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
 "\n",
-"*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*\n",
-"\n",
-"Below is an example using a summarization model from Hugging Face."
+"*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Custom Content Formatter"
+"## Examples"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Example: LlaMa 2 completions with real-time endpoints"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"HaSeul won her first music show trophy with \"So What\" on Mnet's M Countdown. Loona released their second EP titled [#] (read as hash] on February 5, 2020. HaSeul did not take part in the promotion of the album because of mental health issues. On October 19, 2020, they released their third EP called [12:00]. It was their first album to enter the Billboard 200, debuting at number 112. On June 2, 2021, the group released their fourth EP called Yummy-Yummy. On August 27, it was announced that they are making their Japanese debut on September 15 under Universal Music Japan sublabel EMI Records.\n"
-]
-}
-],
+"outputs": [],
+"source": [
+"from langchain.schema import HumanMessage\n",
+"from langchain_community.llms.azureml_endpoint import (\n",
+"    AzureMLEndpointApiType,\n",
+"    AzureMLOnlineEndpoint,\n",
+"    LlamaContentFormatter,\n",
+")\n",
+"\n",
+"llm = AzureMLOnlineEndpoint(\n",
+"    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
+"    endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+"    endpoint_api_key=\"my-api-key\",\n",
+"    content_formatter=LlamaContentFormatter(),\n",
+"    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
+")\n",
+"response = llm.invoke(\"Write me a song about sparkling water:\")\n",
+"response"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Model parameters can also be indicated during invocation:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"response = llm.invoke(\"Write me a song about sparkling water:\", temperature=0.5)\n",
+"response"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Example: Text completions with pay-as-you-go deployments (model as a service)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from langchain_community.llms.azureml_endpoint import (\n",
+"    AzureMLEndpointApiType,\n",
+"    AzureMLOnlineEndpoint,\n",
+"    LlamaContentFormatter,\n",
+")\n",
+"\n",
+"llm = AzureMLOnlineEndpoint(\n",
+"    endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions\",\n",
+"    endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
+"    endpoint_api_key=\"my-api-key\",\n",
+"    content_formatter=LlamaContentFormatter(),\n",
+"    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
+")\n",
+"response = llm.invoke(\"Write me a song about sparkling water:\")\n",
+"response"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Example: Custom content formatter\n",
+"\n",
+"Below is an example using a summarization model from Hugging Face."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
 "source": [
 "import json\n",
 "import os\n",
@@ -104,6 +181,7 @@
 "content_formatter = CustomFormatter()\n",
 "\n",
 "llm = AzureMLOnlineEndpoint(\n",
+"    endpoint_api_type=\"realtime\",\n",
 "    endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
 "    endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
 "    model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
@@ -132,7 +210,7 @@
 "that Loona will release the double A-side single, \"Hula Hoop / Star Seed\" on September 15, with a physical CD release on October \n",
 "20.[53] In December, Chuu filed an injunction to suspend her exclusive contract with Blockberry Creative.[54][55]\n",
 "\"\"\"\n",
-"summarized_text = llm(large_text)\n",
+"summarized_text = llm.invoke(large_text)\n",
 "print(summarized_text)"
 ]
 },
@@ -140,22 +218,14 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Dolly with LLMChain"
+"### Example: Dolly with LLMChain"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Many people are willing to talk about themselves; it's others who seem to be stuck up. Try to understand others where they're coming from. Like minded people can build a tribe together.\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "from langchain.chains import LLMChain\n",
 "from langchain.prompts import PromptTemplate\n",
@@ -177,31 +247,22 @@
 ")\n",
 "\n",
 "chain = LLMChain(llm=llm, prompt=prompt)\n",
-"print(chain.run({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
+"print(chain.invoke({\"word_count\": 100, \"topic\": \"how to make friends\"}))"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Serializing an LLM\n",
+"## Serializing an LLM\n",
 "You can also save and load LLM configurations"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"\u001b[1mAzureMLOnlineEndpoint\u001b[0m\n",
-"Params: {'deployment_name': 'databricks-dolly-v2-12b-4', 'model_kwargs': {'temperature': 0.2, 'max_tokens': 150, 'top_p': 0.8, 'frequency_penalty': 0.32, 'presence_penalty': 0.072}}\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "from langchain_community.llms.loading import load_llm\n",
 "\n",
@@ -224,9 +285,9 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3 (ipykernel)",
+"display_name": "langchain",
 "language": "python",
-"name": "python3"
+"name": "langchain"
 },
 "language_info": {
 "codemirror_mode": {
@ -238,7 +299,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.11.5"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
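The "Serializing an LLM" cell above shows only its import in this diff. Here is a minimal sketch of the save/load round trip it describes, assuming `llm` is the `AzureMLOnlineEndpoint` built earlier and an arbitrary `azureml.json` file name:

```python
# Sketch: save and reload an AzureMLOnlineEndpoint configuration.
# Assumes `llm` was constructed as in the notebook cells above;
# "azureml.json" is an arbitrary file name chosen for illustration.
from langchain_community.llms.loading import load_llm

llm.save("azureml.json")               # write the LLM configuration to disk
loaded_llm = load_llm("azureml.json")  # rebuild the LLM from that configuration
print(loaded_llm)
```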
--- a/libs/community/langchain_community/chat_models/azureml_endpoint.py
+++ b/libs/community/langchain_community/chat_models/azureml_endpoint.py
@@ -1,8 +1,8 @@
 import json
 from typing import Any, Dict, List, Optional, cast
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -10,16 +10,24 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.pydantic_v1 import SecretStr, validator
-from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+from langchain_core.outputs import ChatGeneration, ChatResult
 
 from langchain_community.llms.azureml_endpoint import (
-    AzureMLEndpointClient,
+    AzureMLBaseEndpoint,
+    AzureMLEndpointApiType,
     ContentFormatterBase,
 )
 
 
 class LlamaContentFormatter(ContentFormatterBase):
+    def __init__(self):
+        raise TypeError(
+            "`LlamaContentFormatter` is deprecated for chat models. Use "
+            "`LlamaChatContentFormatter` instead."
+        )
+
+
+class LlamaChatContentFormatter(ContentFormatterBase):
     """Content formatter for `LLaMA`."""
 
     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@@ -45,7 +53,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
+            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -53,79 +61,96 @@ class LlamaContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaContentFormatter.SUPPORTED_ROLES]
+                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
                 Supported roles for the LLaMa Foundation Model: {supported}"""
             )
 
-    def _format_request_payload(
-        self, messages: List[BaseMessage], model_kwargs: Dict
-    ) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self,
+        messages: List[BaseMessage],
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType,
+    ) -> bytes:
+        """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaContentFormatter._convert_message_to_dict(message)
+            LlamaChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        prompt = json.dumps(
-            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
-        )
-        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": chat_messages,
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
+        return str.encode(request_payload)
 
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
-        """Formats the request according to the chosen api"""
-        return str.encode(prompt)
-
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> ChatGeneration:
         """Formats response"""
-        return json.loads(output)["output"]
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)["output"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice.strip(),
+                    type="assistant",
+                ),
+                generation_info=None,
+            )
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        f"model. Expected `dict` but `{type(choice)}` was received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return ChatGeneration(
+                message=BaseMessage(
+                    content=choice["message"]["content"].strip(),
+                    type=choice["message"]["role"],
+                ),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
-class AzureMLChatOnlineEndpoint(SimpleChatModel):
-    """`AzureML` Chat models API.
+class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint chat models.
 
     Example:
         .. code-block:: python
 
             azure_chat = AzureMLChatOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLEndpointApiType.realtime,
                 endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
+                content_formatter=chat_content_formatter,
             )
-    """
+    """  # noqa: E501
 
-    endpoint_url: str = ""
-    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_URL`."""
-
-    endpoint_api_key: SecretStr = convert_to_secret_str("")
-    """Authentication Key for Endpoint. Should be passed to constructor or specified as
-        env var `AZUREML_ENDPOINT_API_KEY`."""
-
-    http_client: Any = None  #: :meta private:
-
-    content_formatter: Any = None
-    """The content formatter that provides an input and output
-    transform function to handle formats between the LLM and
-    the endpoint"""
-
-    model_kwargs: Optional[dict] = None
-    """Keyword arguments to pass to the model."""
-
-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exist in environment."""
-        values["endpoint_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
-        )
-        endpoint_url = get_from_dict_or_env(
-            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
-        )
-        http_client = AzureMLEndpointClient(
-            endpoint_url, values["endpoint_api_key"].get_secret_value()
-        )
-        return http_client
-
     @property
     def _identifying_params(self) -> Dict[str, Any]:
@@ -140,13 +165,13 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
         """Return type of llm."""
         return "azureml_chat_endpoint"
 
-    def _call(
+    def _generate(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> ChatResult:
         """Call out to an AzureML Managed Online endpoint.
 
         Args:
             messages: The messages in the conversation with the chat model.
@@ -158,12 +183,17 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel):
                 response = azureml_model("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs.update(kwargs)
+        if stop:
+            _model_kwargs["stop"] = stop
 
-        request_payload = self.content_formatter._format_request_payload(
-            messages, _model_kwargs
+        request_payload = self.content_formatter.format_request_payload(
+            messages, _model_kwargs, self.endpoint_api_type
         )
-        response_payload = self.http_client.call(request_payload, **kwargs)
-        generated_text = self.content_formatter.format_response_payload(
-            response_payload
+        response_payload = self.http_client.call(
+            body=request_payload, run_manager=run_manager
         )
-        return generated_text
+        generations = self.content_formatter.format_response_payload(
+            response_payload, self.endpoint_api_type
+        )
+        return ChatResult(generations=[generations])
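For quick reference, the two request bodies `LlamaChatContentFormatter.format_request_payload` now emits, reconstructed from the `api_type` branches above with illustrative values:

```python
# Request bodies produced by LlamaChatContentFormatter.format_request_payload,
# per the api_type branch above; the message and kwargs are illustrative.
import json

chat_messages = [{"role": "user", "content": "Hello"}]
model_kwargs = {"temperature": 0.8}

# endpoint_api_type == "realtime": AzureML managed-inference schema
realtime_body = json.dumps(
    {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
)

# endpoint_api_type == "serverless": OpenAI-style chat-completions schema
serverless_body = json.dumps({"messages": chat_messages, **model_kwargs})
```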
--- a/libs/community/langchain_community/llms/azureml_endpoint.py
+++ b/libs/community/langchain_community/llms/azureml_endpoint.py
@@ -2,12 +2,14 @@ import json
 import urllib.request
 import warnings
 from abc import abstractmethod
+from enum import Enum
 from typing import Any, Dict, List, Mapping, Optional
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models.llms import LLM
-from langchain_core.pydantic_v1 import BaseModel, validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator, validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 
 
 class AzureMLEndpointClient(object):
@@ -26,7 +28,12 @@ class AzureMLEndpointClient(object):
         self.endpoint_api_key = endpoint_api_key
         self.deployment_name = deployment_name
 
-    def call(self, body: bytes, **kwargs: Any) -> bytes:
+    def call(
+        self,
+        body: bytes,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> bytes:
         """call."""
 
         # The azureml-model-deployment header will force the request to go to a
@@ -45,6 +52,16 @@ class AzureMLEndpointClient(object):
         return result
 
 
+class AzureMLEndpointApiType(str, Enum):
+    """Azure ML endpoints API types. Use `realtime` for models deployed in hosted
+    infrastructure, or `serverless` for models deployed as a service with a
+    pay-as-you-go billing or PTU.
+    """
+
+    realtime = "realtime"
+    serverless = "serverless"
+
+
 class ContentFormatterBase:
     """Transform request and response of AzureML endpoint to match with
     required schema.
@@ -61,7 +78,8 @@ class ContentFormatterBase:
         def format_request_payload(
             self,
             prompt: str,
-            model_kwargs: Dict
+            model_kwargs: Dict,
+            api_type: AzureMLEndpointApiType,
         ) -> bytes:
             input_str = json.dumps(
                 {
@@ -71,7 +89,9 @@ class ContentFormatterBase:
             )
             return str.encode(input_str)
 
-        def format_response_payload(self, output: str) -> str:
+        def format_response_payload(
+            self, output: str, api_type: AzureMLEndpointApiType
+        ) -> str:
             response_json = json.loads(output)
             return response_json[0]["0"]
     """
@@ -81,6 +101,12 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""
 
+    format_error_msg: Optional[str] = (
+        "Error while formatting response payload for chat model of type "
+        "`{api_type}`. Are you using the right formatter for the deployed "
+        "model and endpoint type?"
+    )
+
     @staticmethod
     def escape_special_characters(prompt: str) -> str:
         """Escapes any special characters in `prompt`"""
@@ -100,15 +126,32 @@ class ContentFormatterBase:
 
         return prompt
 
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        """Supported APIs for the given formatter. Azure ML supports
+        deploying models using different hosting methods. Each method may have
+        a different API structure."""
+
+        return [AzureMLEndpointApiType.realtime]
+
     @abstractmethod
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    def format_request_payload(
+        self,
+        prompt: str,
+        model_kwargs: Dict,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> bytes:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """
 
     @abstractmethod
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+    ) -> Generation:
         """Formats the response body according to the output
         schema of the model. Returns the data type that is
         received from the response.
@@ -118,15 +161,27 @@ class ContentFormatterBase:
 class GPT2ContentFormatter(ContentFormatterBase):
     """Content handler for GPT2"""
 
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs}
         )
         return str.encode(request_payload)
 
-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["0"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)
 
 
 class OSSContentFormatter(GPT2ContentFormatter):
@@ -148,21 +203,39 @@ class OSSContentFormatter(GPT2ContentFormatter):
 class HFContentFormatter(ContentFormatterBase):
     """Content handler for LLMs from the HuggingFace catalog."""
 
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {"inputs": [f'"{prompt}"'], "parameters": model_kwargs}
         )
         return str.encode(request_payload)
 
-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]["generated_text"]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]["0"]["generated_text"]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)
 
 
 class DollyContentFormatter(ContentFormatterBase):
     """Content handler for the Dolly-v2-12b model"""
 
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
         request_payload = json.dumps(
             {
@@ -172,49 +245,88 @@ class DollyContentFormatter(ContentFormatterBase):
             }
         )
         return str.encode(request_payload)
 
-    def format_response_payload(self, output: bytes) -> str:
-        return json.loads(output)[0]
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
+        try:
+            choice = json.loads(output)[0]
+        except (KeyError, IndexError, TypeError) as e:
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+        return Generation(text=choice)
 
 
 class LlamaContentFormatter(ContentFormatterBase):
     """Content formatter for LLaMa"""
 
-    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
+    @property
+    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
+        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+
+    def format_request_payload(
+        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
+    ) -> bytes:
         """Formats the request according to the chosen api"""
         prompt = ContentFormatterBase.escape_special_characters(prompt)
-        request_payload = json.dumps(
-            {
-                "input_data": {
-                    "input_string": [f'"{prompt}"'],
-                    "parameters": model_kwargs,
-                }
-            }
-        )
+        if api_type == AzureMLEndpointApiType.realtime:
+            request_payload = json.dumps(
+                {
+                    "input_data": {
+                        "input_string": [f'"{prompt}"'],
+                        "parameters": model_kwargs,
+                    }
+                }
+            )
+        elif api_type == AzureMLEndpointApiType.serverless:
+            request_payload = json.dumps({"prompt": prompt, **model_kwargs})
+        else:
+            raise ValueError(
+                f"`api_type` {api_type} is not supported by this formatter"
+            )
        return str.encode(request_payload)
 
-    def format_response_payload(self, output: bytes) -> str:
+    def format_response_payload(
+        self, output: bytes, api_type: AzureMLEndpointApiType
+    ) -> Generation:
         """Formats response"""
-        return json.loads(output)[0]["0"]
+        if api_type == AzureMLEndpointApiType.realtime:
+            try:
+                choice = json.loads(output)[0]["0"]
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(text=choice)
+        if api_type == AzureMLEndpointApiType.serverless:
+            try:
+                choice = json.loads(output)["choices"][0]
+                if not isinstance(choice, dict):
+                    raise TypeError(
+                        "Endpoint response is not well formed for a chat "
+                        f"model. Expected `dict` but `{type(choice)}` was "
+                        "received."
+                    )
+            except (KeyError, IndexError, TypeError) as e:
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            return Generation(
+                text=choice["text"].strip(),
+                generation_info=dict(
+                    finish_reason=choice.get("finish_reason"),
+                    logprobs=choice.get("logprobs"),
+                ),
+            )
+        raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
-class AzureMLOnlineEndpoint(LLM, BaseModel):
-    """Azure ML Online Endpoint models.
-
-    Example:
-        .. code-block:: python
-
-            azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_key="my-api-key",
-                content_formatter=content_formatter,
-            )
-    """  # noqa: E501
+class AzureMLBaseEndpoint(BaseModel):
+    """Azure ML Online Endpoint models."""
 
     endpoint_url: str = ""
     """URL of pre-existing Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_URL`."""
 
-    endpoint_api_key: str = ""
+    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
+    """Type of the endpoint being consumed. Possible values are `serverless` for
+        pay-as-you-go and `realtime` for real-time endpoints."""
+
+    endpoint_api_key: SecretStr = convert_to_secret_str("")
     """Authentication Key for Endpoint. Should be passed to constructor or specified as
         env var `AZUREML_ENDPOINT_API_KEY`."""
 
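Likewise, the LLM-side `LlamaContentFormatter.format_request_payload` produces one of two bodies, reconstructed from the branches above with illustrative values:

```python
# Request bodies produced by LlamaContentFormatter (LLM side), per the
# api_type branches above; prompt and kwargs are illustrative.
import json

prompt = "Write me a song about sparkling water:"
model_kwargs = {"temperature": 0.8, "max_new_tokens": 400}

# endpoint_api_type == "realtime": AzureML managed-inference schema
realtime_body = json.dumps(
    {"input_data": {"input_string": [f'"{prompt}"'], "parameters": model_kwargs}}
)

# endpoint_api_type == "serverless": OpenAI-style completions schema
serverless_body = json.dumps({"prompt": prompt, **model_kwargs})
```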
@@ -232,22 +344,106 @@ class AzureMLBaseEndpoint(BaseModel):
     model_kwargs: Optional[dict] = None
     """Keyword arguments to pass to the model."""
 
-    @validator("http_client", always=True, allow_reuse=True)
-    @classmethod
-    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
-        """Validate that api key and python package exists in environment."""
-        endpoint_key = get_from_dict_or_env(
-            values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY"
-        )
-        endpoint_url = get_from_dict_or_env(
-            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
-        )
-        deployment_name = get_from_dict_or_env(
-            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
-        )
-        http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name)
+    @root_validator(pre=True)
+    def validate_environ(cls, values: Dict) -> Dict:
+        values["endpoint_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
+        )
+        values["endpoint_url"] = get_from_dict_or_env(
+            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
+        )
+        values["deployment_name"] = get_from_dict_or_env(
+            values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", ""
+        )
+        values["endpoint_api_type"] = get_from_dict_or_env(
+            values,
+            "endpoint_api_type",
+            "AZUREML_ENDPOINT_API_TYPE",
+            AzureMLEndpointApiType.realtime,
+        )
+
+        return values
+
+    @validator("content_formatter")
+    def validate_content_formatter(
+        cls, field_value: Any, values: Dict
+    ) -> ContentFormatterBase:
+        """Validate that content formatter is supported by endpoint type."""
+        endpoint_api_type = values.get("endpoint_api_type")
+        if endpoint_api_type not in field_value.supported_api_types:
+            raise ValueError(
+                f"Content formatter {type(field_value)} is not supported by this "
+                f"endpoint. Supported types are {field_value.supported_api_types} "
+                f"but endpoint is {endpoint_api_type}."
+            )
+        return field_value
+
+    @validator("endpoint_url")
+    def validate_endpoint_url(cls, field_value: Any) -> str:
+        """Validate that endpoint url is complete."""
+        if field_value.endswith("/"):
+            field_value = field_value[:-1]
+        if field_value.endswith("inference.ml.azure.com"):
+            raise ValueError(
+                "`endpoint_url` should contain the full invocation URL including "
+                "`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
+                "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
+            )
+        return field_value
+
+    @validator("endpoint_api_type")
+    def validate_endpoint_api_type(
+        cls, field_value: Any, values: Dict
+    ) -> AzureMLEndpointApiType:
+        """Validate that endpoint api type is compatible with the URL format."""
+        endpoint_url = values.get("endpoint_url")
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+            "/score"
+        ):
+            raise ValueError(
+                "Endpoints of type `realtime` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
+                " If your endpoint URL ends with `/v1/completions` or "
+                "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
+            )
+        if field_value == AzureMLEndpointApiType.serverless and not (
+            endpoint_url.endswith("/v1/completions")
+            or endpoint_url.endswith("/v1/chat/completions")
+        ):
+            raise ValueError(
+                "Endpoints of type `serverless` should follow the format "
+                "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions`"
+                " or `https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions`"
+            )
+
+        return field_value
+
+    @validator("http_client", always=True)
+    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
+        """Validate that api key and python package exist in environment."""
+        endpoint_url = values.get("endpoint_url")
+        endpoint_key = values.get("endpoint_api_key")
+        deployment_name = values.get("deployment_name")
+
+        http_client = AzureMLEndpointClient(
+            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+        )
         return http_client
+
+
+class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
+    """Azure ML Online Endpoint models.
+
+    Example:
+        .. code-block:: python
+
+            azure_llm = AzureMLOnlineEndpoint(
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
+                endpoint_api_type=AzureMLEndpointApiType.realtime,
+                endpoint_api_key="my-api-key",
+                content_formatter=content_formatter,
+            )
+    """  # noqa: E501
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
@@ -262,16 +458,17 @@ class AzureMLOnlineEndpoint(LLM, BaseModel):
         """Return type of llm."""
         return "azureml_endpoint"
 
-    def _call(
+    def _generate(
         self,
-        prompt: str,
+        prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        """Call out to an AzureML Managed Online endpoint.
+    ) -> LLMResult:
+        """Run the LLM on the given prompts.
+
         Args:
-            prompt: The prompt to pass into the model.
+            prompts: The prompts to pass into the model.
             stop: Optional list of stop words to use when generating.
         Returns:
             The string generated by the model.
|
|||||||
response = azureml_model("Tell me a joke.")
|
response = azureml_model("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
_model_kwargs = self.model_kwargs or {}
|
_model_kwargs = self.model_kwargs or {}
|
||||||
|
_model_kwargs.update(kwargs)
|
||||||
|
if stop:
|
||||||
|
_model_kwargs["stop"] = stop
|
||||||
|
generations = []
|
||||||
|
|
||||||
request_payload = self.content_formatter.format_request_payload(
|
for prompt in prompts:
|
||||||
prompt, _model_kwargs
|
request_payload = self.content_formatter.format_request_payload(
|
||||||
)
|
prompt, _model_kwargs, self.endpoint_api_type
|
||||||
response_payload = self.http_client.call(request_payload, **kwargs)
|
)
|
||||||
generated_text = self.content_formatter.format_response_payload(
|
response_payload = self.http_client.call(
|
||||||
response_payload
|
body=request_payload, run_manager=run_manager
|
||||||
)
|
)
|
||||||
return generated_text
|
generated_text = self.content_formatter.format_response_payload(
|
||||||
|
response_payload, self.endpoint_api_type
|
||||||
|
)
|
||||||
|
generations.append([generated_text])
|
||||||
|
|
||||||
|
return LLMResult(generations=generations)
|
||||||
|
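Since the class now derives from `BaseLLM`, single-prompt and batched calls both route through `_generate`. A short usage sketch, assuming an `llm` configured as above:

    # `invoke` returns a single string; `generate` returns an LLMResult with
    # one list of Generation objects per input prompt.
    text = llm.invoke("Tell me a joke.")
    result = llm.generate(["Tell me a joke.", "Tell me a fact."])
    for gens in result.generations:
        print(gens[0].text)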
@@ -5,31 +5,31 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_community.chat_models.azureml_endpoint import (
     AzureMLChatOnlineEndpoint,
-    LlamaContentFormatter,
+    LlamaChatContentFormatter,
 )


 def test_llama_call() -> None:
     """Test valid call to Open Source Foundation Model."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="Foo")])
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="Foo")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


-def test_timeout_kwargs() -> None:
-    """Test that the timeout kwarg works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(messages=[HumanMessage(content="FOO")], timeout=60)
+def test_temperature_kwargs() -> None:
+    """Test that the temperature kwarg works."""
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)


 def test_message_history() -> None:
     """Test that multiple messages works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
-    response = chat(
-        messages=[
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    response = chat.invoke(
+        [
             HumanMessage(content="Hello."),
             AIMessage(content="Hello!"),
             HumanMessage(content="How are you doing?"),
@@ -40,7 +40,7 @@ def test_message_history() -> None:


 def test_multiple_messages() -> None:
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
     message = HumanMessage(content="Hi!")
     response = chat.generate([[message], [message]])

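These chat tests also reflect the migration from the deprecated `chat(messages=[...])` call style to `chat.invoke([...])`. A sketch of the new style against a pay-as-you-go chat endpoint (URL and key are placeholders):

    from langchain_core.messages import HumanMessage

    from langchain_community.chat_models.azureml_endpoint import (
        AzureMLChatOnlineEndpoint,
        LlamaChatContentFormatter,
    )
    from langchain_community.llms.azureml_endpoint import AzureMLEndpointApiType

    chat = AzureMLChatOnlineEndpoint(
        endpoint_url="https://my-endpoint.eastus2.inference.ml.azure.com/v1/chat/completions",
        endpoint_api_type=AzureMLEndpointApiType.serverless,
        endpoint_api_key="my-api-key",
        content_formatter=LlamaChatContentFormatter(),
    )
    response = chat.invoke([HumanMessage(content="How are you doing?")])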
@@ -7,6 +7,7 @@ from typing import Dict
 from urllib.request import HTTPError

 import pytest
+from langchain_core.pydantic_v1 import ValidationError

 from langchain_community.llms.azureml_endpoint import (
     AzureMLOnlineEndpoint,
@@ -26,7 +27,7 @@ def test_gpt2_call() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -38,7 +39,7 @@ def test_hf_call() -> None:
         deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
         content_formatter=HFContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -50,7 +51,7 @@ def test_dolly_call() -> None:
         deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
         content_formatter=DollyContentFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -81,7 +82,7 @@ def test_custom_formatter() -> None:
         deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
         content_formatter=CustomFormatter(),
     )
-    output = llm("Foo")
+    output = llm.invoke("Foo")
     assert isinstance(output, str)


@@ -93,7 +94,7 @@ def test_missing_content_formatter() -> None:
         endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_invalid_request_format() -> None:
@@ -123,7 +124,31 @@ def test_invalid_request_format() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=CustomContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")
+
+
+def test_incorrect_url() -> None:
+    """Testing AzureML Endpoint for an incorrect URL"""
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url="https://endpoint.inference.com",
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")
+
+
+def test_incorrect_api_type() -> None:
+    with pytest.raises(ValidationError):
+        llm = AzureMLOnlineEndpoint(
+            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
+            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
+            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
+            endpoint_api_type="serverless",
+            content_formatter=OSSContentFormatter(),
+        )
+        llm.invoke("Foo")
+
+
 def test_incorrect_key() -> None:
@@ -135,7 +160,7 @@ def test_incorrect_key() -> None:
         deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
         content_formatter=OSSContentFormatter(),
     )
-    llm("Foo")
+    llm.invoke("Foo")


 def test_saving_loading_llm(tmp_path: Path) -> None:
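The two new negative tests exercise the URL and API-type validators shown earlier: an unrecognized URL, or a `serverless` API type paired with a `/score`-style URL, should fail at construction time with a pydantic `ValidationError` rather than at request time. A sketch of that fail-fast behavior (placeholder values):

    from langchain_core.pydantic_v1 import ValidationError

    from langchain_community.llms.azureml_endpoint import (
        AzureMLOnlineEndpoint,
        OSSContentFormatter,
    )

    try:
        AzureMLOnlineEndpoint(
            endpoint_url="https://my-endpoint.eastus2.inference.ml.azure.com/score",
            endpoint_api_type="serverless",  # mismatch: serverless expects a /v1/... URL
            endpoint_api_key="my-api-key",
            content_formatter=OSSContentFormatter(),
        )
    except ValidationError:
        print("rejected: API type does not match the endpoint URL")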