Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-26 08:33:49 +00:00
community[patch]: Support Streaming in Azure Machine Learning (#18246)
- [x] **PR title**: "community: Support streaming in Azure ML and few naming changes"
- [x] **PR message**:
  - **Description:** Added support for streaming to `azureml_endpoint`. Also renamed `AzureMLEndpointApiType.realtime` to `AzureMLEndpointApiType.dedicated`, added the new classes `CustomOpenAIChatContentFormatter` and `CustomOpenAIContentFormatter`, and updated the classes `LlamaChatContentFormatter` and `LlamaContentFormatter` to now show a deprecation warning when instantiated.

---------

Co-authored-by: Sachin Paryani <saparan@microsoft.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
  parent: ecb11a4a32
  commit: 25c9f3d1d1
@@ -40,7 +40,7 @@
 "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
 "\n",
 "* `endpoint_url`: The REST endpoint url provided by the endpoint.\n",
-"* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+"* `endpoint_api_type`: Use `endpoint_type='dedicated'` when deploying models to **Dedicated endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
 "* `endpoint_api_key`: The API key provided by the endpoint"
 ]
 },
@@ -52,9 +52,9 @@
 "\n",
 "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
 "\n",
-"* `LLamaChatContentFormatter`: Formats request and response data for LLaMa2-chat\n",
+"* `CustomOpenAIChatContentFormatter`: Formats request and response data for models like LLaMa2-chat that follow the OpenAI API spec for request and response.\n",
 "\n",
-"*Note: `langchain.chat_models.azureml_endpoint.LLamaContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.LLamaChatContentFormatter`.*\n",
+"*Note: `langchain.chat_models.azureml_endpoint.LlamaChatContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.CustomOpenAIChatContentFormatter`.*\n",
 "\n",
 "You can implement custom content formatters specific for your model deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`."
 ]
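For readers following along: a custom formatter, as the notebook suggests, might look like the sketch below. It is an illustration, not part of this commit; the `inputs`/`outputs` JSON keys are made up for a hypothetical model, while `ContentFormatterBase`, `escape_special_characters`, and the `api_type` defaults follow the signatures visible later in this diff.

```python
import json
from typing import Dict, List

from langchain_core.outputs import Generation

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)


class MyModelContentFormatter(ContentFormatterBase):
    """Hypothetical formatter for a model expecting {"inputs": [prompt]}
    and answering with {"outputs": ["..."]}."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(
        self,
        prompt: str,
        model_kwargs: Dict,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> bytes:
        # Escape quotes and newlines so the prompt survives JSON embedding.
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        return str.encode(json.dumps({"inputs": [prompt], **model_kwargs}))

    def format_response_payload(
        self,
        output: bytes,
        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
    ) -> Generation:
        # Assumed response shape: {"outputs": ["generated text"]}.
        return Generation(text=json.loads(output)["outputs"][0])
```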
@@ -65,20 +65,7 @@
 "source": [
 "## Examples\n",
 "\n",
-"The following section cotain examples about how to use this class:"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"from langchain_community.chat_models.azureml_endpoint import (\n",
-" AzureMLEndpointApiType,\n",
-" LlamaChatContentFormatter,\n",
-")\n",
-"from langchain_core.messages import HumanMessage"
+"The following section contains examples about how to use this class:"
 ]
 },
 {
@@ -105,14 +92,17 @@
 }
 ],
 "source": [
-"from langchain_community.chat_models.azureml_endpoint import LlamaContentFormatter\n",
+"from langchain_community.chat_models.azureml_endpoint import (\n",
+" AzureMLEndpointApiType,\n",
+" CustomOpenAIChatContentFormatter,\n",
+")\n",
 "from langchain_core.messages import HumanMessage\n",
 "\n",
 "chat = AzureMLChatOnlineEndpoint(\n",
 " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
-" endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+" endpoint_api_type=AzureMLEndpointApiType.dedicated,\n",
 " endpoint_api_key=\"my-api-key\",\n",
-" content_formatter=LlamaChatContentFormatter(),\n",
+" content_formatter=CustomOpenAIChatContentFormatter(),\n",
 ")\n",
 "response = chat.invoke(\n",
 " [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
@@ -137,7 +127,7 @@
 " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
 " endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
 " endpoint_api_key=\"my-api-key\",\n",
-" content_formatter=LlamaChatContentFormatter,\n",
+" content_formatter=CustomOpenAIChatContentFormatter,\n",
 ")\n",
 "response = chat.invoke(\n",
 " [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
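Since this PR's headline feature is streaming, the serverless chat example above could also be consumed incrementally. A sketch (it relies on the `_stream` method this PR adds to `AzureMLChatOnlineEndpoint` later in this diff; `chat` is the endpoint object from the example above):

```python
# Stream tokens instead of waiting for the full reply. `stream()` comes from
# BaseChatModel and delegates to the `_stream` implementation added by this PR.
for chunk in chat.stream(
    [HumanMessage(content="Will the Collatz conjecture ever be solved?")]
):
    print(chunk.content, end="", flush=True)
```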
@@ -149,7 +139,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"If you need to pass additional parameters to the model, use `model_kwards` argument:"
+"If you need to pass additional parameters to the model, use `model_kwargs` argument:"
 ]
 },
 {
@@ -162,7 +152,7 @@
 " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
 " endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
 " endpoint_api_key=\"my-api-key\",\n",
-" content_formatter=LlamaChatContentFormatter,\n",
+" content_formatter=CustomOpenAIChatContentFormatter,\n",
 " model_kwargs={\"temperature\": 0.8},\n",
 ")"
 ]
@@ -204,7 +194,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.9.1"
 }
 },
 "nbformat": 4,

@@ -29,7 +29,7 @@
 "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
 "\n",
 "* `endpoint_url`: The REST endpoint url provided by the endpoint.\n",
-"* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
+"* `endpoint_api_type`: Use `endpoint_type='dedicated'` when deploying models to **Dedicated endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
 "* `endpoint_api_key`: The API key provided by the endpoint.\n",
 "* `deployment_name`: (Optional) The deployment name of the model using the endpoint."
 ]
@@ -45,7 +45,7 @@
 "* `GPT2ContentFormatter`: Formats request and response data for GPT2\n",
 "* `DollyContentFormatter`: Formats request and response data for the Dolly-v2\n",
 "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
-"* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n",
+"* `CustomOpenAIContentFormatter`: Formats request and response data for models like LLaMa2 that follow OpenAI API compatible scheme.\n",
 "\n",
 "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*"
 ]
@@ -72,15 +72,15 @@
 "source": [
 "from langchain_community.llms.azureml_endpoint import (\n",
 " AzureMLEndpointApiType,\n",
-" LlamaContentFormatter,\n",
+" CustomOpenAIContentFormatter,\n",
 ")\n",
 "from langchain_core.messages import HumanMessage\n",
 "\n",
 "llm = AzureMLOnlineEndpoint(\n",
 " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
-" endpoint_api_type=AzureMLEndpointApiType.realtime,\n",
+" endpoint_api_type=AzureMLEndpointApiType.dedicated,\n",
 " endpoint_api_key=\"my-api-key\",\n",
-" content_formatter=LlamaContentFormatter(),\n",
+" content_formatter=CustomOpenAIContentFormatter(),\n",
 " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
 ")\n",
 "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
@@ -119,7 +119,7 @@
 "source": [
 "from langchain_community.llms.azureml_endpoint import (\n",
 " AzureMLEndpointApiType,\n",
-" LlamaContentFormatter,\n",
+" CustomOpenAIContentFormatter,\n",
 ")\n",
 "from langchain_core.messages import HumanMessage\n",
 "\n",
@@ -127,7 +127,7 @@
 " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions\",\n",
 " endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
 " endpoint_api_key=\"my-api-key\",\n",
-" content_formatter=LlamaContentFormatter(),\n",
+" content_formatter=CustomOpenAIContentFormatter(),\n",
 " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
 ")\n",
 "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
@@ -181,7 +181,7 @@
 "content_formatter = CustomFormatter()\n",
 "\n",
 "llm = AzureMLOnlineEndpoint(\n",
-" endpoint_api_type=\"realtime\",\n",
+" endpoint_api_type=\"dedicated\",\n",
 " endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
 " endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
 " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",

@@ -1,16 +1,37 @@
 import json
-from typing import Any, Dict, List, Optional, cast
+import warnings
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Type,
+    cast,
+)
 
-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
+    ToolMessageChunk,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.llms.azureml_endpoint import (
     AzureMLBaseEndpoint,
@@ -25,12 +46,12 @@ class LlamaContentFormatter(ContentFormatterBase):
     def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
-            "`LlamaChatContentFormatter` instead."
+            "`CustomOpenAIContentFormatter` instead."
         )
 
 
-class LlamaChatContentFormatter(ContentFormatterBase):
-    """Content formatter for `LLaMA`."""
+class CustomOpenAIChatContentFormatter(ContentFormatterBase):
+    """Chat Content formatter for models with OpenAI like API scheme."""
 
     SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
 
@@ -55,7 +76,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         elif (
             isinstance(message, ChatMessage)
-            and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES
+            and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
         ):
             return {
                 "role": message.role,
@@ -63,7 +84,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             }
         else:
             supported = ",".join(
-                [role for role in LlamaChatContentFormatter.SUPPORTED_ROLES]
+                [role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
             )
             raise ValueError(
                 f"""Received unsupported role.
@@ -72,7 +93,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
 
     def format_messages_request_payload(
         self,
@@ -82,10 +103,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
-            LlamaChatContentFormatter._convert_message_to_dict(message)
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
             for message in messages
         ]
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             request_payload = json.dumps(
                 {
                     "input_data": {
@@ -105,10 +129,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def format_response_payload(
         self,
         output: bytes,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> ChatGeneration:
         """Formats response"""
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
@@ -143,6 +170,20 @@ class LlamaChatContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Chat Content formatter for Llama."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`LlamaChatContentFormatter` will be deprecated in the future.
+                Please use `CustomOpenAIChatContentFormatter` instead.
+            """
+        )
+
+
 class MistralChatContentFormatter(LlamaChatContentFormatter):
     """Content formatter for `Mistral`."""
 
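A quick way to see the backwards-compatibility behavior above (a sketch; the warning text is the one passed to `warnings.warn` in this hunk):

```python
import warnings

from langchain_community.chat_models.azureml_endpoint import (
    LlamaChatContentFormatter,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    formatter = LlamaChatContentFormatter()  # still usable, but warns

# The warning points users at the replacement class.
assert any(
    "CustomOpenAIChatContentFormatter" in str(w.message) for w in caught
)
```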
@@ -187,8 +228,8 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
     Example:
         .. code-block:: python
             azure_llm = AzureMLOnlineEndpoint(
-                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
+                endpoint_api_type=AzureMLApiType.serverless,
                 endpoint_api_key="my-api-key",
                 content_formatter=chat_content_formatter,
             )
@@ -239,3 +280,143 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
             response_payload, self.endpoint_api_type
         )
         return ChatResult(generations=[generations])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        client = openai.OpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in client.chat.completions.create(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
+        timeout = None if "timeout" not in kwargs else kwargs["timeout"]
+
+        import openai
+
+        params = {}
+        client_params = {
+            "api_key": self.endpoint_api_key.get_secret_value(),
+            "base_url": self.endpoint_url,
+            "timeout": timeout,
+            "default_headers": None,
+            "default_query": None,
+            "http_client": None,
+        }
+
+        async_client = openai.AsyncOpenAI(**client_params)
+        message_dicts = [
+            CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
+            for m in messages
+        ]
+        params = {"stream": True, "stop": stop, "model": None, **kwargs}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in await async_client.chat.completions.create(
+            messages=message_dicts, **params
+        ):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )
+            yield chunk
+
+
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: Dict = {}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    if _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = _dict["tool_calls"]
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"])
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
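Taken together, these methods let the chat endpoint stream through the OpenAI client against endpoints that follow the OpenAI wire format. A hedged usage sketch, not part of this commit (endpoint URL and key are placeholders; a serverless deployment following the OpenAI scheme is assumed):

```python
import asyncio

from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    CustomOpenAIChatContentFormatter,
)
from langchain_core.messages import HumanMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=CustomOpenAIChatContentFormatter(),
)


async def main() -> None:
    # _astream is wired into the public astream() API by BaseChatModel.
    async for chunk in chat.astream([HumanMessage(content="Count to five.")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```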
|
@@ -62,12 +62,14 @@ class AzureMLEndpointClient(object):
 
 
 class AzureMLEndpointApiType(str, Enum):
-    """Azure ML endpoints API types. Use `realtime` for models deployed in hosted
-    infrastructure, or `serverless` for models deployed as a service with a
+    """Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
+    infrastructure (also known as Online Endpoints in Azure Machine Learning),
+    or `serverless` for models deployed as a service with a
     pay-as-you-go billing or PTU.
     """
 
-    realtime = "realtime"
+    dedicated = "dedicated"
+    realtime = "realtime"  #: Deprecated
     serverless = "serverless"
 
 
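Because `AzureMLEndpointApiType` subclasses `str`, both the new member and the deprecated alias still compare equal to their raw string values, so string-based configs keep working; a small illustration:

```python
from langchain_community.llms.azureml_endpoint import AzureMLEndpointApiType

# New canonical name plus the deprecated alias kept for compatibility.
assert AzureMLEndpointApiType.dedicated == "dedicated"
assert AzureMLEndpointApiType.realtime == "realtime"
assert AzureMLEndpointApiType("dedicated") is AzureMLEndpointApiType.dedicated
```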
@@ -141,13 +143,13 @@ class ContentFormatterBase:
     deploying models using different hosting methods. Each method may have
     a different API structure."""
 
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(
         self,
         prompt: str,
         model_kwargs: Dict,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> Any:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
@@ -159,7 +161,7 @@ class ContentFormatterBase:
     def format_response_payload(
         self,
         output: bytes,
-        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
     ) -> Generation:
         """Formats the response body according to the output
         schema of the model. Returns the data type that is
@@ -172,7 +174,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -214,7 +216,7 @@ class HFContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -240,7 +242,7 @@ class DollyContentFormatter(ContentFormatterBase):
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime]
+        return [AzureMLEndpointApiType.dedicated]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@@ -264,19 +266,22 @@ class DollyContentFormatter(ContentFormatterBase):
         return Generation(text=choice)
 
 
-class LlamaContentFormatter(ContentFormatterBase):
-    """Content formatter for LLaMa"""
+class CustomOpenAIContentFormatter(ContentFormatterBase):
+    """Content formatter for models that use the OpenAI like API scheme."""
 
     @property
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
-        return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
+        return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
 
     def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         """Formats the request according to the chosen api"""
         prompt = ContentFormatterBase.escape_special_characters(prompt)
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             request_payload = json.dumps(
                 {
                     "input_data": {
@@ -297,7 +302,10 @@ class LlamaContentFormatter(ContentFormatterBase):
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         """Formats response"""
-        if api_type == AzureMLEndpointApiType.realtime:
+        if api_type in [
+            AzureMLEndpointApiType.dedicated,
+            AzureMLEndpointApiType.realtime,
+        ]:
             try:
                 choice = json.loads(output)[0]["0"]
             except (KeyError, IndexError, TypeError) as e:
@@ -324,6 +332,22 @@ class LlamaContentFormatter(ContentFormatterBase):
         raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
 
 
+class LlamaContentFormatter(CustomOpenAIContentFormatter):
+    """Deprecated: Kept for backwards compatibility
+
+    Content formatter for Llama."""
+
+    content_formatter: Any = None
+
+    def __init__(self) -> None:
+        super().__init__()
+        warnings.warn(
+            """`LlamaContentFormatter` will be deprecated in the future.
+                Please use `CustomOpenAIContentFormatter` instead.
+            """
+        )
+
+
 class AzureMLBaseEndpoint(BaseModel):
     """Azure ML Online Endpoint models."""
 
@@ -331,9 +355,9 @@ class AzureMLBaseEndpoint(BaseModel):
     """URL of pre-existing Endpoint. Should be passed to constructor or specified as
     env var `AZUREML_ENDPOINT_URL`."""
 
-    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime
+    endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
     """Type of the endpoint being consumed. Possible values are `serverless` for
-    pay-as-you-go and `realtime` for real-time endpoints. """
+    pay-as-you-go and `dedicated` for dedicated endpoints. """
 
     endpoint_api_key: SecretStr = convert_to_secret_str("")
     """Authentication Key for Endpoint. Should be passed to constructor or specified as
@@ -348,6 +372,8 @@ class AzureMLBaseEndpoint(BaseModel):
 
     http_client: Any = None  #: :meta private:
 
+    max_retries: int = 1
+
     content_formatter: Any = None
     """The content formatter that provides an input and output
     transform function to handle formats between the LLM and
@@ -371,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
             values,
             "endpoint_api_type",
             "AZUREML_ENDPOINT_API_TYPE",
-            AzureMLEndpointApiType.realtime,
+            AzureMLEndpointApiType.dedicated,
         )
         values["timeout"] = get_from_dict_or_env(
             values,
@@ -404,7 +430,7 @@ class AzureMLBaseEndpoint(BaseModel):
         if field_value.endswith("inference.ml.azure.com"):
             raise ValueError(
                 "`endpoint_url` should contain the full invocation URL including "
-                "`/score` for `endpoint_api_type='realtime'` or `/v1/completions` "
+                "`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
                 "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
             )
         return field_value
@@ -415,11 +441,15 @@ class AzureMLBaseEndpoint(BaseModel):
     ) -> AzureMLEndpointApiType:
         """Validate that endpoint api type is compatible with the URL format."""
         endpoint_url = values.get("endpoint_url")
-        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(  # type: ignore[union-attr]
-            "/score"
+        if (
+            (
+                field_value == AzureMLEndpointApiType.dedicated
+                or field_value == AzureMLEndpointApiType.realtime
+            )
+            and not endpoint_url.endswith("/score")  # type: ignore[union-attr]
         ):
             raise ValueError(
-                "Endpoints of type `realtime` should follow the format "
+                "Endpoints of type `dedicated` should follow the format "
                 "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
                 " If your endpoint URL ends with `/v1/completions` or"
                 "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
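For reference, the widened check keeps old `realtime` configs working while still enforcing the URL shape; a hedged sketch of the failure mode (`pytest` usage is illustrative, and the error surfaces as a pydantic `ValidationError`, which subclasses `ValueError`):

```python
import pytest

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    AzureMLOnlineEndpoint,
    CustomOpenAIContentFormatter,
)

# A dedicated endpoint URL must end in /score, so pairing a
# /v1/completions URL with the dedicated type should raise.
with pytest.raises(ValueError):
    AzureMLOnlineEndpoint(
        endpoint_url="https://my-endpoint.eastus.inference.ml.azure.com/v1/completions",
        endpoint_api_type=AzureMLEndpointApiType.dedicated,
        endpoint_api_key="my-api-key",
        content_formatter=CustomOpenAIContentFormatter(),
    )
```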
@@ -461,7 +491,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
         .. code-block:: python
             azure_llm = AzureMLOnlineEndpoint(
                 endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
-                endpoint_api_type=AzureMLApiType.realtime,
+                endpoint_api_type=AzureMLApiType.dedicated,
                 endpoint_api_key="my-api-key",
                 timeout=120,
                 content_formatter=content_formatter,

@@ -5,13 +5,15 @@ from langchain_core.outputs import ChatGeneration, LLMResult
 
 from langchain_community.chat_models.azureml_endpoint import (
     AzureMLChatOnlineEndpoint,
-    LlamaChatContentFormatter,
+    CustomOpenAIChatContentFormatter,
 )
 
 
 def test_llama_call() -> None:
     """Test valid call to Open Source Foundation Model."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     response = chat.invoke([HumanMessage(content="Foo")])
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
@@ -19,7 +21,9 @@ def test_llama_call() -> None:
 
 def test_temperature_kwargs() -> None:
     """Test that timeout kwarg works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
     assert isinstance(response, BaseMessage)
     assert isinstance(response.content, str)
@@ -27,7 +31,9 @@ def test_temperature_kwargs() -> None:
 
 def test_message_history() -> None:
     """Test that multiple messages works."""
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     response = chat.invoke(
         [
             HumanMessage(content="Hello."),
@@ -40,7 +46,9 @@ def test_message_history() -> None:
 
 
 def test_multiple_messages() -> None:
-    chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter())
+    chat = AzureMLChatOnlineEndpoint(
+        content_formatter=CustomOpenAIChatContentFormatter()
+    )
     message = HumanMessage(content="Hi!")
     response = chat.generate([[message], [message]])

@@ -2,10 +2,10 @@ from langchain_community.llms.azureml_endpoint import (
     AzureMLEndpointClient,
     AzureMLOnlineEndpoint,
     ContentFormatterBase,
+    CustomOpenAIContentFormatter,
     DollyContentFormatter,
     GPT2ContentFormatter,
     HFContentFormatter,
-    LlamaContentFormatter,
     OSSContentFormatter,
 )
@@ -16,6 +16,6 @@ __all__ = [
     "OSSContentFormatter",
     "HFContentFormatter",
     "DollyContentFormatter",
-    "LlamaContentFormatter",
+    "CustomOpenAIContentFormatter",
     "AzureMLOnlineEndpoint",
 ]
|